diff --git a/RELEASE.md b/RELEASE.md index 39fc46ac6357300ea2b3365fa4c6d432d2a206db..fdf10407fda21444f1d0ee6cf20650d2659b146f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -160,7 +160,7 @@ answered questions, and were part of inspiring discussions. # Release 1.4.1 ## Bug Fixes and Other Changes -* `LinearClassifier` fix for CloudML Engine. +* `LinearClassifier` fix. # Release 1.4.0 diff --git a/WORKSPACE b/WORKSPACE index 7ae39374f18efd3bddb9aae9bb8dba5c13a61dcc..1e38a9a8cd754886fc5232531816b875de0879a3 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -41,12 +41,12 @@ load("//tensorflow:workspace.bzl", "tf_workspace") tf_workspace() new_http_archive( - name = "inception5h", + name = "inception_v1", build_file = "models.BUILD", - sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364", + sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", urls = [ - "http://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip", - "http://download.tensorflow.org/models/inception5h.zip", + "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", + "http://download.tensorflow.org/models/inception_v1.zip", ], ) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index da37564697a7159518a6ba71271f911713e3e58e..63849943e4bdef132a9fdaead3d57811e24e686b 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -376,6 +376,7 @@ package_group( "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", + "//third_party/py/tensor2tensor/...", ], ) @@ -441,9 +442,6 @@ filegroup( "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", "//tensorflow/contrib/batching:all_files", - "//tensorflow/contrib/batching/kernels:all_files", - "//tensorflow/contrib/batching/test_util:all_files", - "//tensorflow/contrib/batching/util:all_files", "//tensorflow/contrib/bayesflow:all_files", "//tensorflow/contrib/boosted_trees:all_files", "//tensorflow/contrib/boosted_trees/estimator_batch:all_files", @@ -537,7 +535,7 @@ filegroup( "//tensorflow/contrib/periodic_resample:all_files", "//tensorflow/contrib/predictor:all_files", "//tensorflow/contrib/py2tf:all_files", - "//tensorflow/contrib/py2tf/convert:all_files", + "//tensorflow/contrib/py2tf/converters:all_files", "//tensorflow/contrib/py2tf/pyct:all_files", "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", "//tensorflow/contrib/quantize:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index f258bcd95684cc58c2ead3886b3ce74e4af6c5aa..c46cb32aa46af474c889095564d46c5f2399c3ad 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -26,6 +26,18 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +filegroup( + name = "srcs", + srcs = glob( + [ + "*.cc", + "*.h", + ], + exclude = ["*test*"], + ), + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "c_api_internal", srcs = ["c_api.h"], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 6fc75a98f1e05c3971cb4546bd16f015c25b6709..3c7f041b39f01d9b8b187079b00e0c5ad99a38cc 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -927,6 +927,7 @@ int TF_DeviceListCount(const TF_DeviceList* list) { status->status = InvalidArgument("index out of bounds"); \ return err_val; \ } \ + status->status = Status::OK(); \ return list->response[index].accessor; \ } @@ -1469,7 +1470,13 @@ int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, } int TF_OperationNumControlInputs(TF_Operation* oper) { - return oper->node.in_edges().size() - oper->node.num_inputs(); + int count = 0; + for (const auto* edge : oper->node.in_edges()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { + ++count; + } + } + return count; } int TF_OperationGetControlInputs(TF_Operation* oper, @@ -1477,7 +1484,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper, int max_control_inputs) { int count = 0; for (const auto* edge : oper->node.in_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { if (count < max_control_inputs) { control_inputs[count] = ToOperation(edge->src()); } @@ -1490,7 +1497,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper, int TF_OperationNumControlOutputs(TF_Operation* oper) { int count = 0; for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { ++count; } } @@ -1502,7 +1509,7 @@ int TF_OperationGetControlOutputs(TF_Operation* oper, int max_control_outputs) { int count = 0; for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { if (count < max_control_outputs) { control_outputs[count] = ToOperation(edge->dst()); } diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index df697e16d3d3fcaac66f967c0d3938450f0b0be6..01954eb235f1a93d943c2ec7ea4c5ca44785d402 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -575,7 +575,7 @@ TEST(CAPI, ImportGraphDef) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); - // Create a graph with two nodes: x and 3 + // Create a simple graph. Placeholder(graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); @@ -586,7 +586,7 @@ TEST(CAPI, ImportGraphDef) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); - // Export to a GraphDef + // Export to a GraphDef. TF_Buffer* graph_def = TF_NewBuffer(); TF_GraphToGraphDef(graph, graph_def, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); @@ -606,6 +606,31 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(feed != nullptr); ASSERT_TRUE(neg != nullptr); + // Test basic structure of the imported graph. + EXPECT_EQ(0, TF_OperationNumInputs(scalar)); + EXPECT_EQ(0, TF_OperationNumInputs(feed)); + ASSERT_EQ(1, TF_OperationNumInputs(neg)); + TF_Output neg_input = TF_OperationInput({neg, 0}); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Test that we can't see control edges involving the source and sink nodes. + TF_Operation* control_ops[100]; + EXPECT_EQ(0, TF_OperationNumControlInputs(scalar)); + EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100)); + + EXPECT_EQ(0, TF_OperationNumControlInputs(feed)); + EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(feed)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100)); + + EXPECT_EQ(0, TF_OperationNumControlInputs(neg)); + EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(neg)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100)); + // Import it again, with an input mapping, return outputs, and a return // operation, into the same graph. TF_DeleteImportGraphDefOptions(opts); @@ -629,7 +654,7 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(neg2 != nullptr); // Check input mapping - TF_Output neg_input = TF_OperationInput({neg, 0}); + neg_input = TF_OperationInput({neg, 0}); EXPECT_EQ(scalar, neg_input.oper); EXPECT_EQ(0, neg_input.index); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 3429009a71a863ae6b69b5cd29ace3c7fd078f4c..6acc2fec0063a8592e8e22a00b530df05a08cdb8 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ +#ifndef TENSORFLOW_C_C_TEST_UTIL_H_ +#define TENSORFLOW_C_C_TEST_UTIL_H_ #include "tensorflow/c/c_api.h" @@ -136,4 +136,4 @@ class CSession { std::vector targets_; }; -#endif // THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ +#endif // TENSORFLOW_C_C_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 04a415b909ba3e76dfc12a3522f85d290ba6d36f..a76c8f5ec05fc3199addc67857d7bb2ea0e263c2 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -118,6 +118,23 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); } +void TFE_ContextSetThreadLocalDevicePlacementPolicy( + TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { + tensorflow::mutex_lock ml(ctx->policy_map_mu); + ctx->thread_local_policies[std::this_thread::get_id()] = policy; +} + +extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( + TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->policy_map_mu); + auto policy_map_it = + ctx->thread_local_policies.find(std::this_thread::get_id()); + if (policy_map_it != ctx->thread_local_policies.end()) { + return policy_map_it->second; + } + return ctx->policy; +} + TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); @@ -435,10 +452,17 @@ tensorflow::Status ValidateInputTypeAndPlacement( const tensorflow::Device* actual_device = op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; if (expected_device != actual_device) { - switch (ctx->policy) { - case TFE_DEVICE_PLACEMENT_EXPLICIT: + switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { + case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: // TODO(xpan): See if we could bubble python related error up // to python level. + if (op->inputs[i].dtype() == tensorflow::DT_INT32) { + // Note: enabling silent copies of int32 tensors to match behavior + // of graph mode. + break; + } + TF_FALLTHROUGH_INTENDED; + case TFE_DEVICE_PLACEMENT_EXPLICIT: return tensorflow::errors::InvalidArgument( "Tensors on conflicting devices:" " cannot compute ", diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 9b0fd037da35f31e9b97f29b1269bbca9e4c849d..387de078948e5076d0b069d6380dfc04ea6254df 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -61,14 +61,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( // Controls how to act when we try to run an operation on a given device but // some input tensors are not on that device. typedef enum TFE_ContextDevicePlacementPolicy { - // The default: running operations with input tensors on the wrong device will - // fail. + // Running operations with input tensors on the wrong device will fail. TFE_DEVICE_PLACEMENT_EXPLICIT = 0, // Copy the tensor to the right device but log a warning. TFE_DEVICE_PLACEMENT_WARN = 1, // Silently copy the tensor, which has a performance cost since the // operation will be blocked till the copy completes. TFE_DEVICE_PLACEMENT_SILENT = 2, + // Default placement policy which silently copies int32 tensors but not other + // dtypes. + TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( @@ -93,6 +95,18 @@ TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, // ops. TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx); +// Sets a thread-local device placement policy. After this call, other calls to +// TFE_Execute in the same thread will use the device policy specified here +// instead of the device policy used to construct the context. This has no +// effect on the device policy used by other program threads. +TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( + TFE_Context*, TFE_ContextDevicePlacementPolicy); + +// Returns the device placement policy to be used by this context in the current +// thread. +TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy +TFE_ContextGetDevicePlacementPolicy(TFE_Context*); + // A handle to a tensor on a device. // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 55a04d48bad63a8c19ffdc39675b1e1b70ac80d7..a6f76c732f2a4c2402a27cd69c101d028dbb8fcc 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include "tensorflow/c/c_api.h" @@ -37,7 +38,8 @@ limitations under the License. struct TFE_ContextOptions { TF_SessionOptions session_options; - TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT}; + TFE_ContextDevicePlacementPolicy policy{ + TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; }; struct TFE_Context { @@ -45,6 +47,12 @@ struct TFE_Context { TFE_ContextDevicePlacementPolicy policy; + // Note: we cannot use C++11 thread_local here as there is no concept of a + // thread-local-object-local variable in C++11. + tensorflow::mutex policy_map_mu; + std::unordered_map + thread_local_policies GUARDED_BY(policy_map_mu); + // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. TF_Session* session; tensorflow::Rendezvous* rendezvous; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 423a7e1ff71bfdc5f51e36ae63359869ea079ddc..18e7a64435e6c7e51998a744abd615edc7ad4318 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -321,6 +321,55 @@ TEST(CAPI, TensorHandleSilentCopy) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopyLocal) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, + TFE_DEVICE_PLACEMENT_SILENT); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + // Disable the test if no GPU is present. + if (num_devices > 1) { + const int device_to_use = 1; + const string name(TF_DeviceListName(devices, device_to_use, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* hgpu = + TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); + TFE_OpSetDevice(matmul, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + TEST(CAPI, Execute) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index b51ef2b53122802fef598a26bd6f1843976f11b0..aa9d9e06b28c54cb8869eb547d36ee3cb0d4e6b8 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ -#define THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ +#ifndef TENSORFLOW_C_PYTHON_API_H_ +#define TENSORFLOW_C_PYTHON_API_H_ #include "tensorflow/c/c_api.h" @@ -39,4 +39,4 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ +#endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index ddcee3deee444382f4bdb206de6f06ee62265a51..c9ade5fb83ff5b80a62bc960d1af1dc55f458c4e 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -673,7 +673,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index 1b5f7dd923731e56ab3d7e5288d17fef9eb3beb0..c7256a7dc384e652fa1bddfe3aa9893491c2b14c 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_gen_lib.h" @@ -28,4 +28,4 @@ void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ diff --git a/tensorflow/cc/framework/grad_op_registry.h b/tensorflow/cc/framework/grad_op_registry.h index 190b96f68506c6b5252d6c0184f1712310477a8a..0fc5abb20c884a66539682099497e2c8511a620f 100644 --- a/tensorflow/cc/framework/grad_op_registry.h +++ b/tensorflow/cc/framework/grad_op_registry.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ #include @@ -72,4 +72,4 @@ class GradOpRegistry { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h index d055c60d09c2f33fb1f61165f75b2d04618620b7..1aa215a9088335580667e0c23c7244e6e5047f1a 100644 --- a/tensorflow/cc/framework/gradient_checker.h +++ b/tensorflow/cc/framework/gradient_checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -60,4 +60,4 @@ Status ComputeGradientError(const Scope& scope, const Output& x, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h index 717f6f0636d3dd1a546ef7477b100bbfc86ba13d..0a377ad56d139a6ec26ea97b4e1e43495d0b3165 100644 --- a/tensorflow/cc/framework/gradients.h +++ b/tensorflow/cc/framework/gradients.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -49,4 +49,4 @@ Output NoGradient(); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 8d4154220c4b18f9286094b10c1b1e96eb4e31e7..a085e1d6e2de5ad63d11eb8979ae64c26b91366f 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#define TENSORFLOW_CC_FRAMEWORK_OPS_H_ #include @@ -296,4 +296,4 @@ class InputList { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_OPS_H_ diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 0225ac047291d6297af558fddad6e5315389ff40..30c32bd44b0f22d6b29dd3836d431807d0216818 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ +#define TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ #include #include @@ -242,4 +242,4 @@ struct CompositeOpScopes { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 968c366550ef6f46557cd9b5662d9d0719b31531..8efcfed20d0b86d86d8c20a3d8630c7c6bc909c3 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#define TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ #include "tensorflow/cc/framework/scope.h" @@ -117,4 +117,4 @@ class Scope::Impl { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ diff --git a/tensorflow/cc/framework/testutil.h b/tensorflow/cc/framework/testutil.h index a3e19870ec847bcd4f0e0bf0e71dda724024d5d2..7ad6fb4a676639f5d6d3da6a7c08de1894162f0c 100644 --- a/tensorflow/cc/framework/testutil.h +++ b/tensorflow/cc/framework/testutil.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#define TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -44,4 +44,4 @@ void GetTensor(const Scope& scope, const std::vector& assign_vars, } // namespace test } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h index 8f592accc93573cb8953a5ab25c04881ca0c2333..cb4e579c8548294ec45b0c3f42cb844e0b87c390 100644 --- a/tensorflow/cc/framework/while_gradients.h +++ b/tensorflow/cc/framework/while_gradients.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#define TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -37,4 +37,4 @@ Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ diff --git a/tensorflow/cc/gradients/grad_testutil.h b/tensorflow/cc/gradients/grad_testutil.h index d31f412754ff59cc7782b14e285071a8d4218d08..70c81f1a73a394322c602a5c51e3c2a40aca2397 100644 --- a/tensorflow/cc/gradients/grad_testutil.h +++ b/tensorflow/cc/gradients/grad_testutil.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#ifndef TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#define TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -32,4 +32,4 @@ Status CallGradFunction(const Scope& scope, const Operation& op, } // namespace test } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#endif // TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index d11fda475b3db58bf83cdb94079c8fde8d1170f7..424a683665f31b5e25eeceeb40477fc31640ce90 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ +#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_ +#define TENSORFLOW_CC_OPS_CONST_OP_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -82,4 +82,4 @@ std::vector AsNodeOutList(const Scope& scope, } // namespace ops } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ +#endif // TENSORFLOW_CC_OPS_CONST_OP_H_ diff --git a/tensorflow/cc/ops/standard_ops.h b/tensorflow/cc/ops/standard_ops.h index 0c021f0b3ac02c596e0511e650a3caa0002c25d1..98f53010ecf78f769c7d89d6aafc48fdb772f42e 100644 --- a/tensorflow/cc/ops/standard_ops.h +++ b/tensorflow/cc/ops/standard_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_ #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/candidate_sampling_ops.h" @@ -37,4 +37,4 @@ limitations under the License. #include "tensorflow/cc/ops/training_ops.h" #include "tensorflow/cc/ops/user_ops.h" -#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_ diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h index a04476056a058ff0951a6347e8ffc05bc5ff5023..727237b5c7ad4d31dba1aaaf6d5600773d69223e 100644 --- a/tensorflow/cc/ops/while_loop.h +++ b/tensorflow/cc/ops/while_loop.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ -#define THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#define TENSORFLOW_CC_OPS_WHILE_LOOP_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -71,4 +71,4 @@ Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, } // namespace ops } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#endif // TENSORFLOW_CC_OPS_WHILE_LOOP_H_ diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h index e1ce315d3c125ef9f0cb16209e199690211df440..6077c45c5854fd5812ccb7c91522f93ed4e54883 100644 --- a/tensorflow/cc/profiler/profiler.h +++ b/tensorflow/cc/profiler/profiler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_ -#define THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_ +#ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_ +#define TENSORFLOW_CC_PROFILER_PROFILER_H_ #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -94,4 +94,4 @@ class Profiler { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_ +#endif // TENSORFLOW_CC_PROFILER_PROFILER_H_ diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index c940df8a8761d97a859be3af30980ff79ca3577a..645a3f101d1ae7dda88ec4ca622c694dc5a7a919 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ namespace tensorflow { @@ -47,4 +47,4 @@ constexpr char kSavedModelVariablesFilename[] = "variables"; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index 3d634dd51543bed8d3c074bdc56c251f97d56976..a8e098fa5440e7a8f72fd0b52737dcb06435b908 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -15,8 +15,8 @@ limitations under the License. /// SavedModel loading functions and SavedModelBundle struct. -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ +#define TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ #include #include @@ -61,4 +61,4 @@ bool MaybeSavedModelDirectory(const string& export_dir); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ diff --git a/tensorflow/cc/saved_model/signature_constants.h b/tensorflow/cc/saved_model/signature_constants.h index b2d39bd55beb48a05489236395a208e41deb9c8f..7d8c07f5cf0a310c20193469cb6d18664f738d96 100644 --- a/tensorflow/cc/saved_model/signature_constants.h +++ b/tensorflow/cc/saved_model/signature_constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ namespace tensorflow { @@ -66,4 +66,4 @@ static constexpr char kRegressOutputs[] = "outputs"; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index b71cb263ca42dab7e830c1880ec4b311bc272f82..68a090e0c4cf79cfa87771a80447b8112fc37fb9 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ namespace tensorflow { @@ -32,4 +32,4 @@ constexpr char kSavedModelTagTrain[] = "train"; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h index bd5e0516c8999dc235747ccec75a57542b0f9bf7..b10f29805a4515f9d49426cc41e0d375cd32b072 100644 --- a/tensorflow/cc/tools/freeze_saved_model.h +++ b/tensorflow/cc/tools/freeze_saved_model.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ -#define THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ +#ifndef TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ +#define TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ #include @@ -40,4 +40,4 @@ Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ +#endif // TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index 57244a4f0adeb9775e35445f77205f3d221ee05b..52a81a50284aec36bba4e56a0232c886cb0cb6cf 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -71,7 +71,7 @@ class FreezeTest : public ::testing::Test { return Status::OK(); } - // Adds `graph_def` to `saved_model_bundle` and intializes a session with + // Adds `graph_def` to `saved_model_bundle` and initializes a session with // `init_node`. Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def, const string& init_node, diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index 0e01b19cd98bc797b7bb25da55c05d96f3eb93c7..7168b775251d38687d604b5294405389a8c1b04f 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#ifndef TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#define TENSORFLOW_CC_TRAINING_COORDINATOR_H_ #include #include @@ -128,4 +128,4 @@ class Coordinator { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#endif // TENSORFLOW_CC_TRAINING_COORDINATOR_H_ diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 2d3450032388bfee96055f23cf621af0fa4731ae..21189b4b046b87b8609483109096fda6144681b8 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ -#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ #include #include @@ -137,4 +137,4 @@ class QueueRunner : public RunnerInterface { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#endif // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f7c6cd293a8a4788bd73cc42c5c61e60d4a2c110..314f5506b16e2c28736d9d39aa6c856d50885108 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -403,11 +403,6 @@ tf_xla_py_test( disabled_backends = [ "gpu", ], - tags = [ - "manual", - "no_oss", - "notap", - ], deps = [ ":xla_test", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 65706b35d616eb4dce94f0a7056a1604a97ff4c1..16856bd736ed408da29c3199c4593eb578775128 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -43,7 +43,7 @@ class BinaryOpsTest(XLATestCase): output = op(pa, pb) result = session.run(output, {pa: a, pb: b}) if equality_test is None: - equality_test = self.assertAllClose + equality_test = self.assertAllCloseAccordingToType equality_test(result, expected, rtol=1e-3) def _testSymmetricBinary(self, op, a, b, expected, equality_test=None): @@ -54,14 +54,20 @@ class BinaryOpsTest(XLATestCase): """Tests closeness of two lists of floats.""" self.assertEqual(len(result), len(expected)) for i in range(len(result)): - self.assertAllClose(result[i], expected[i], rtol) + self.assertAllCloseAccordingToType(result[i], expected[i], rtol) def testFloatOps(self): for dtype in self.float_types: + if dtype == dtypes.bfloat16.as_numpy_dtype: + a = -1.01 + b = 4.1 + else: + a = -1.001 + b = 4.01 self._testBinary( lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), - np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype), - np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype), + np.array([[[[-1, 2.00009999], [-3, b]]]], dtype=dtype), + np.array([[[[a, 2], [-3.00009, 4]]]], dtype=dtype), expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index e84b790037c3b341a01c0a4d295e36890ea1f28e..538fa8e8e570b83ed681ecc0501285520cabdecb 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -65,7 +65,7 @@ class RGBToHSVTest(XLATestCase): # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) self.assertAllClose(batch2, join2) - self.assertAllClose(batch2, inp) + self.assertAllCloseAccordingToType(batch2, inp, bfloat16_atol=0.03) def testRGBToHSVRoundTrip(self): data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -77,21 +77,25 @@ class RGBToHSVTest(XLATestCase): hsv = image_ops.rgb_to_hsv(placeholder) rgb = image_ops.hsv_to_rgb(hsv) rgb_tf = rgb.eval(feed_dict={placeholder: rgb_np}) - self.assertAllClose(rgb_tf, rgb_np) + self.assertAllCloseAccordingToType(rgb_tf, rgb_np, bfloat16_atol=0.03) def testRGBToHSVNumpy(self): """Tests the RGB to HSV conversion matches a reference implementation.""" for nptype in self.float_types: rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype) rgb_np = rgb_flat.reshape(4, 4, 4, 3) - hsv_np = np.array([colorsys.rgb_to_hsv(r, g, b) for r, g, b in rgb_flat]) + hsv_np = np.array([ + colorsys.rgb_to_hsv( + r.astype(np.float64), g.astype(np.float64), b.astype(np.float64)) + for r, g, b in rgb_flat + ]) hsv_np = hsv_np.reshape(4, 4, 4, 3) with self.test_session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv_op = image_ops.rgb_to_hsv(placeholder) hsv_tf = hsv_op.eval(feed_dict={placeholder: rgb_np}) - self.assertAllClose(hsv_tf, hsv_np) + self.assertAllCloseAccordingToType(hsv_tf, hsv_np) class AdjustContrastTest(XLATestCase): @@ -427,7 +431,8 @@ class ResizeBilinearTest(XLATestCase): np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), align_corners=True) out = sess.run(resized, {grads: grads_np[np.newaxis, :, :, np.newaxis]}) - self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + self.assertAllCloseAccordingToType(expected[np.newaxis, :, :, np.newaxis], + out) def testAlignCorners1x2To3x2(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 0a6fe04d3cdd29f1d40d33be1f4319090e7ba3d1..8e4b8a38336c5e8b2e10edc4c81502eeebb628d2 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -67,8 +67,10 @@ class UnaryOpsTest(XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: - equality_test = self.assertAllCloseAccordingToType - equality_test(result, expected, rtol=rtol, atol=atol) + self.assertAllCloseAccordingToType( + result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + else: + equality_test(result, expected, rtol=rtol, atol=atol) def ListsAreClose(self, result, expected, rtol, atol): """Tests closeness of two lists of floats.""" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h index 575086e118080f6799a54d3ae6409b2b641c4341..ca57be3d47b95d71b07746e50256070e0a4f4c09 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.h +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -31,4 +31,4 @@ Status TensorShapeToConstant(const TensorShape& input_shape, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 79da701fd244a461a60588153b601d5c1870fa89..672e19bd93449ccc31f4af5ded23257b197a3c39 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -29,7 +29,7 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, arg_names_(static_data.arg_names), result_names_(static_data.result_names), program_shape_(static_data.program_shape), - hlo_profile_printer_(static_data.hlo_profile_printer) { + hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { // Allocate arg and temp buffers. if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index e0ae3ed9a811bcc49ce8862037a67d293e879e57..48a8c083cacf2f6ecf9dc1817b6174c01385d035 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -26,7 +26,7 @@ limitations under the License. // never use this functionality. namespace xla { class ProgramShape; -class HloProfilePrinter; +class HloProfilePrinterData; } namespace tensorflow { @@ -77,12 +77,14 @@ class XlaCompiledCpuFunction { // [Optional] Arg and result shapes. const xla::ProgramShape* program_shape = nullptr; - // [Optional] Profile printer. Null if profiling is disabled. - const xla::HloProfilePrinter* hlo_profile_printer = nullptr; + // [Optional] Profile printer data. Null if profiling is disabled. + const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; // [Optional] The number of profile counters expected in the profile counter // buffer by the generated code and hlo_profile_printer. 0 if profiling is - // disabled. + // disabled. This information is already present in + // hlo_profile_printer_data but xla::HloProfilePrinterData is forward + // declared so we don't have access to that information here. int64 profile_counters_size = 0; }; @@ -205,10 +207,12 @@ class XlaCompiledCpuFunction { // program shape isn't available. const xla::ProgramShape* ProgramShape() const { return program_shape_; } - bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } - const xla::HloProfilePrinter& hlo_profile_printer() const { + bool hlo_profiling_enabled() const { + return hlo_profile_printer_data_ != nullptr; + } + const xla::HloProfilePrinterData& hlo_profile_printer_data() const { assert(hlo_profiling_enabled()); - return *hlo_profile_printer_; + return *hlo_profile_printer_data_; } private: @@ -234,7 +238,7 @@ class XlaCompiledCpuFunction { const char** arg_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr; - const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr; + const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 584417bc72c8f6645c05912e857b031cfb394e54..1fe6e69ff2dc838152032ac3d7b21de41684c6f6 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -182,10 +182,10 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.program_shape = jit->program_shape_.get(); if (cpu_executable->hlo_profiling_enabled()) { - jit->static_data_.hlo_profile_printer = - &cpu_executable->hlo_profile_printer(); + jit->static_data_.hlo_profile_printer_data = + &cpu_executable->hlo_profile_printer_data(); jit->static_data_.profile_counters_size = - cpu_executable->hlo_profile_printer().profile_counters_size(); + cpu_executable->hlo_profile_printer_data().profile_counters_size(); } return std::move(jit_unique_ptr); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c0c4251eabcd06d7c84ae76f349d657fa9f6d641..ee0aed672e1b264fee0a7f381c334400c55f3581 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -412,10 +412,20 @@ XlaCompiler* XlaOpKernelContext::compiler() const { return XlaContext::Get(context_).compiler(); } -void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); } -void XlaOpKernelContext::CtxFailureWithWarning(Status s) { +void XlaOpKernelContext::CtxFailure(const Status& s) { + context_->CtxFailure(s); +} +void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) { context_->CtxFailureWithWarning(s); } +void XlaOpKernelContext::CtxFailure(const char* file, int line, + const Status& s) { + context_->CtxFailure(file, line, s); +} +void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, + const Status& s) { + context_->CtxFailureWithWarning(file, line, s); +} const xla::Computation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index f1ae81a5aa9d507a3e0dd577568377385b1844e6..6d3b6db2289d6c0b8f266062f9f3baca1145154a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -173,8 +173,10 @@ class XlaOpKernelContext { const xla::ComputationDataHandle& handle); // Helper routines for the OP_REQUIRES macros - void CtxFailure(Status s); - void CtxFailureWithWarning(Status s); + void CtxFailure(const Status& s); + void CtxFailureWithWarning(const Status& s); + void CtxFailure(const char* file, int line, const Status& s); + void CtxFailureWithWarning(const char* file, int line, const Status& s); // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index d82ba63e8ad0b9ceac0eb5f0cd7720cac0cbe6d3..ea4cdb76673b1c99036224bcd754ce4fe1360945 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -67,7 +67,7 @@ class ComputationBuilder { // OpMetadata is often applied to a series of XLA HLO instructions. As a // result, OpMetadata is set on the Computation Builder. All subsequent // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetdata. + // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } // Clears the HloMetadata state. diff --git a/tensorflow/compiler/xla/execution_options_util.h b/tensorflow/compiler/xla/execution_options_util.h index 562da78e837ea6c4a01f0d1170797340fd421ad8..a8ca27ec8dfdc01267ccc9efa6c39093c43d4e2d 100644 --- a/tensorflow/compiler/xla/execution_options_util.h +++ b/tensorflow/compiler/xla/execution_options_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ #include "tensorflow/compiler/xla/xla.pb.h" @@ -26,4 +26,4 @@ ExecutionOptions CreateDefaultExecutionOptions(); } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h index a39999705eddc5728dce028dab64b7358395757e..a8bb8c7a7e6784e555f4e9dad73ecc78c668ac42 100644 --- a/tensorflow/compiler/xla/iterator_util.h +++ b/tensorflow/compiler/xla/iterator_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ #include #include @@ -95,4 +95,4 @@ UnwrappingIterator MakeUnwrappingIterator(NestedIter iter) { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index e88bffd0ba2dacb837c568023f5da1338fea40f3..fe3a4d2f6df47d9f156529e55198a5f339bc8e3c 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -223,6 +223,11 @@ void AllocateFlags() { tensorflow::Flag( "xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(), "Dump compilation artifacts as proto binary into this directory."), + tensorflow::Flag( + "xla_dump_prepass_hlo_proto_to", + flag_values->mutable_xla_dump_prepass_hlo_proto_to(), + "Dump compilation artifacts, before hlo passes are executed, as " + "proto binary into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h index d0ef8e66ab0bcbf88035ae31fe32eb161e32e998..b53157f59c61cf4e0850e006ad3656f4be63a936 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ #include @@ -35,4 +35,4 @@ xla::DebugOptions GetDebugOptionsFromFlags(); } // namespace legacy_flags } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h index 0c238e6a5decffb0339f428e4ea676944479cf1b..e9cf435d83d8345e974d83f8e5340dafeba8e3b2 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ #include #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -148,4 +148,4 @@ inline bool parse_xla_reduce_precision_option( } // namespace legacy_flags } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 7f0201e74ab51f8f9906dd045ae7dfb96158f8e9..89279b659c75ce4775581dfbfa8d830f54ae6fe8 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -830,6 +830,16 @@ std::unique_ptr Literal::Slice( result_literal->Set(indices, value); }); return result_literal; + case C64: + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, complex64 /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + complex64 value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; case S32: result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2bce56b7bd2f91f20ea670d0e7ccaa432c2b5f9f..143c9a2366be5786b7ef2148580caeb97d67d2d8 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -20,79 +20,6 @@ limitations under the License. namespace xla { namespace primitive_util { -template <> -PrimitiveType NativeToPrimitiveType() { - return PRED; -} - -// Unsigned integer -template <> -PrimitiveType NativeToPrimitiveType() { - return U8; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return U16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return U32; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return U64; -} - -// Signed integer -template <> -PrimitiveType NativeToPrimitiveType() { - return S8; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return S16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return S32; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return S64; -} - -// Floating point -template <> -PrimitiveType NativeToPrimitiveType() { - return F32; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return F64; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return BF16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return F16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return C64; -} - bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index cb4583d198b454be1432134a9f6a77dbbbe5bdd8..b26a10ade63a5dad3bf8f9f3a2a33c3c5e67bdb2 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -47,49 +47,81 @@ PrimitiveType NativeToPrimitiveType() { } // Declarations of specializations for each native type which correspond to a -// XLA primitive type. +// XLA primitive type. As an optimization, these are declared inline in the +// header. template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return PRED; +} // Unsigned integer template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U8; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U16; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U32; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U64; +} // Signed integer template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S8; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S16; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S32; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S64; +} // Floating point template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return F32; +} + template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return F64; +} + template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return F16; +} + template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return BF16; +} // Complex template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return C64; +} bool IsFloatingPointType(PrimitiveType type); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 5455adafcded90dbe38b4c444d2bc03fae445888..9cfe1249f50fd3c4b09d5af0c0e17a6f40b024a2 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -555,22 +555,12 @@ class ComputationBuilder(object): A ComputationDataHandle representing the added pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): - padding_config = self._GetPaddingConfigFromTriples(padding_config) + padding_config = GetPaddingConfigFromTriples(padding_config) return _wrap_data_handle( self._client.Pad(_unwrap_data_handle(operand), _unwrap_data_handle(padding_value), padding_config)) - def _GetPaddingConfigFromTriples(self, triples): - """Create PaddingConfig proto from list of triples of integers.""" - padding_config = xla_data_pb2.PaddingConfig() - for lo, hi, interior in triples: - dimension = padding_config.dimensions.add() - dimension.edge_padding_low = lo - dimension.edge_padding_high = hi - dimension.interior_padding = interior - return padding_config - def Reshape(self, operand, dimensions, new_sizes): """Reshape op.""" return _wrap_data_handle( @@ -997,3 +987,14 @@ def get_replica_count(): yet or not. """ return c_api.GetReplicaCount() + + +def GetPaddingConfigFromTriples(triples): + """Create PaddingConfig proto from list of triples of integers.""" + padding_config = xla_data_pb2.PaddingConfig() + for lo, hi, interior in triples: + dimension = padding_config.dimensions.add() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + return padding_config diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 71341c6f1e9a359a6d2a8aa9f2fb97b140ade23d..9a0acda94fb08ee0accfba6c5380f628c07ebaa2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -29,6 +29,11 @@ xla_proto_library( deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) +xla_proto_library( + name = "hlo_profile_printer_data", + srcs = ["hlo_profile_printer_data.proto"], +) + # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", @@ -452,6 +457,7 @@ cc_library( ":hlo_evaluator", ":hlo_execution_profile", ":hlo_module_config", + ":hlo_proto_util", ":platform_util", ":session_proto", ":transfer_manager", @@ -905,6 +911,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1084,6 +1091,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep @@ -1940,6 +1948,16 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_element_type_converter_test", + srcs = ["hlo_element_type_converter_test.cc"], + deps = [ + ":hlo_element_type_converter", + ":hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], @@ -2252,6 +2270,7 @@ cc_library( srcs = ["hlo_profile_printer.cc"], hdrs = ["hlo_profile_printer.h"], deps = [ + ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 90a3f0b6748fc00c9cd9226700805bf243a1acdd..ba82e822b216528c28536181059bc2417048de01 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1741,6 +1741,63 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( } } + // If the pad puts a single non-identity value in each window that we're + // reducing, then this is a broadcast. + HloInstruction* pad_operand = operand->mutable_operand(0); + auto is_effective_broadcast = [&] { + if (window_util::HasStride(window)) { + VLOG(10) << "Window has stride."; + return false; + } + if (!window_util::HasSymmetricPadding(pad_config)) { + VLOG(10) << "Window has uneven padding."; + return false; + } + for (int64 i = 0; i < pad_config.dimensions_size(); ++i) { + const auto& pad_dimension = pad_config.dimensions(i); + if ((pad_dimension.edge_padding_low() != 0 || + pad_dimension.edge_padding_high() != 0) && + pad_operand->shape().dimensions(i) != 1) { + VLOG(10) << "Found non-trivial dimension being padded: " << i; + return false; + } + } + VLOG(10) << "Found to be padding trivial dimensions only."; + + for (int64 i = 0; i < window.dimensions_size(); ++i) { + const auto& pad_dimension = pad_config.dimensions(i); + const WindowDimension& window_dimension = window.dimensions(i); + bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 || + pad_dimension.edge_padding_high() != 0); + if (dimension_has_padding && + window_dimension.size() < pad_dimension.edge_padding_low() + 1) { + VLOG(10) << "Found window did not cover single unpadded element in " + "dimension: " + << i; + return false; + } + if (pad_operand->shape().dimensions(i) != 1 && + window_dimension.size() != 1) { + VLOG(10) << "Found window covers more than one element in non-trivial " + "dimension: " + << i; + return false; + } + } + VLOG(10) << "Found window covers a single unpadded element."; + return true; + }; + if (is_effective_broadcast()) { + VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; + auto fadd = [this](std::unique_ptr x) { + return computation_->AddInstruction(std::move(x)); + }; + return ReplaceWithNewInstruction( + reduce_window, HloInstruction::CreateBroadcastSequence( + /*output_shape=*/reduce_window->shape(), + /*operand=*/pad_operand, fadd)); + } + // Carry out the folding of the pad into reduce_window. VLOG(10) << "Folding pad into reduce-window."; Window new_window = window; @@ -1758,7 +1815,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateReduceWindow( /*shape=*/reduce_window->shape(), - /*operand=*/operand->mutable_operand(0), + /*operand=*/pad_operand, /*init_value=*/reduce_window->mutable_operand(1), /*window=*/new_window, /*reduce_computation=*/function)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e7c4dfb0a1690683bbdb7e61067392b48fdba8a5..e43ea50af45318adf2c95aa69b3e53a5225c5579 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -2495,6 +2496,144 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +struct PadReduceWindowEffectiveBroadcastCase { + std::vector input_spatials; + std::vector symmetric_pad_spatials; + std::vector reduce_window_spatials; + // Whether to use `B F S0 S1` form vs `B S0 S1 F` form. + // + // This doesn't test any different functionality but is useful for making sure + // kBroadcast nodes are well formed. + bool prepend_a; + bool should_become_broadcast; + + string ToTestCaseName() const { + return tensorflow::strings::StrCat( + tensorflow::str_util::Join(input_spatials, ","), ";", + tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", + tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, + ";", should_become_broadcast); + } +}; + +void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) { + *os << c.ToTestCaseName(); +} + +class PadReduceWindowEffectiveBroadcastTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + PadReduceWindowEffectiveBroadcastCase> {}; + +TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { + const auto& param = GetParam(); + + // a and b are parallel bounds we can either turn into a B F S0 S1 or + // `B S0 S1 F` kind of pattern. + auto decorate_spatials = [¶m](tensorflow::gtl::ArraySlice spatials, + int64 a, int64 b) { + std::vector result; + if (param.prepend_a) { + result.push_back(a); + } + for (int64 s : spatials) { + result.push_back(s); + } + if (!param.prepend_a) { + result.push_back(a); + } + result.push_back(b); + return result; + }; + + HloComputation::Builder builder(TestName()); + const Shape input_shape = ShapeUtil::MakeShape( + F32, decorate_spatials(param.input_spatials, 128, 2048)); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "input")); + + PaddingConfig padding = window_util::MakeSymmetricPadding( + decorate_spatials(param.symmetric_pad_spatials, 0, 0)); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape( + F32, decorate_spatials(param.reduce_window_spatials, 128, 2048)), + input, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + padding)); + + std::unique_ptr module = CreateNewModule(); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module->AddEmbeddedComputation(builder.Build()); + } + + TF_ASSERT_OK_AND_ASSIGN( + const Shape output_shape, + ShapeInference::InferPadShape(input_shape, ShapeUtil::MakeShape(F32, {}), + padding)); + Window window = window_util::MakeWindow( + decorate_spatials(param.reduce_window_spatials, 1, 1)); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, pad, zero, window, add_computation)); + + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + ASSERT_TRUE(run_successful); + + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); + + if (param.should_become_broadcast) { + EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); + } else { + EXPECT_THAT(computation->root_instruction(), + op::ReduceWindow(::testing::_, zero)); + } +} + +const std::vector& +PadReduceWindowEffectiveBroadcastCases() { + static auto* cases = new std::vector{ + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6}, + /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, + /*should_become_broadcast=*/true}, // + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6}, + /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false, + /*should_become_broadcast=*/true}, // + {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6}, + /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, + /*should_become_broadcast=*/false}, // + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, + /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true, + /*should_become_broadcast=*/true}, // + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, + /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true, + /*should_become_broadcast=*/false}, // + {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2}, + /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true, + /*should_become_broadcast=*/false}, // + }; + return *cases; +} + +INSTANTIATE_TEST_CASE_P( + PadReduceWindowEffectiveBroadcastInstantiation, + PadReduceWindowEffectiveBroadcastTest, + ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); + class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 33fe11b81db1a1db40285d5c77d8900722025d1c..323620c13186ed5f3c8613adb7e736f33674c270 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -846,14 +846,13 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - if (is_thread_local || instruction->opcode() == HloOpcode::kCustomCall) { - // Custom call operations never have reusable buffers. Also we do not - // reuse thread-local buffers for now, because they are dynamically - // allocated and their lifetimes are hard to compute. + if (is_thread_local) { + // We do not reuse thread-local buffers for now, because they are + // dynamically allocated and their lifetimes are hard to compute. BufferAllocation* allocation = assignment->NewAllocation( *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); VLOG(3) << "New allocation #" << allocation->index() - << " for thread-local/CustomCall: " << *buffer; + << " for thread-local: " << *buffer; continue; } diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 9e96898d9b4215e67c8686d372e4b4e6edd1d88b..b9306a8bb09dc4541014716bb0c5e73e3c93ec85 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -107,6 +107,7 @@ CompileOnlyService::CompileAheadOfTime( computation_tracker_.BuildHloModule( versioned_handle, *module_config, /*include_unreachable_instructions=*/true)); + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index f0507982b3749b179dbd7d76c46d39a209640661..33af77e1a81411ff5e1543d594b6078ed8e7fd1e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -485,7 +485,7 @@ StatusOr> CpuCompiler::RunBackend( std::unordered_map instruction_to_profile_idx; std::unordered_map computation_to_profile_idx; std::unique_ptr hlo_profile_index_map; - std::unique_ptr hlo_profile_printer; + std::unique_ptr hlo_profile_printer_data; if (module->config().hlo_profiling_enabled()) { hlo_profile_index_map = MakeUnique(*module); @@ -505,8 +505,8 @@ StatusOr> CpuCompiler::RunBackend( HloCostAnalysis cost_analysis(shape_size_bytes); TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); - hlo_profile_printer = - CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis); + hlo_profile_printer_data = + CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis); computation_to_profile_idx = hlo_profile_index_map->computation_to_profile_idx(); } @@ -619,7 +619,7 @@ StatusOr> CpuCompiler::RunBackend( cpu_executable.reset(new ParallelCpuExecutable( std::move(jit), std::move(assignment), std::move(module), std::move(function_names), std::move(aligned_constants), - std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); + std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -698,7 +698,7 @@ StatusOr> CpuCompiler::RunBackend( jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new CpuExecutable( std::move(jit), std::move(assignment), std::move(module), function_name, - std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); + std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index f335bd1bbc7376d1cccc0fa6aa1c0a6d6ad559ab..802d0a6fb46890b31d14b1fbf3b2e7d6520caccc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -55,9 +55,9 @@ CpuExecutable::CpuExecutable( std::unique_ptr assignment, std::unique_ptr hlo_module, const string& entry_function_name, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), jit_(std::move(jit)), assignment_(std::move(assignment)) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 50443a59954e222f65fc935e83effdaf6d6c8bf0..267b89a10b3c038dc2048f0ad5b5b343c88ef0f9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -51,7 +51,7 @@ class CpuExecutable : public Executable { std::unique_ptr assignment, std::unique_ptr hlo_module, const string& entry_function_name, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 2271af7b247c2684d371010361308b4d7bcd6423..2924b6365943f0a3ec998d7a77767a76cbb576ae 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -39,4 +39,4 @@ class CpuHloSupportChecker : public HloPassInterface { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h index 2994642356d55df26c31553ef28dc653503d05be..664125ecc95ca5ac10be4201b9120ddbdb9b9821 100644 --- a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ // This file is depended on by kernels that have to build for mobile devices. // For this reason, we avoid relying on TensorFlow and instead only use the @@ -71,4 +71,4 @@ class RegisterCustomCallTarget { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index ebd96c4c42759b71b79408c73814605301af03c1..99c5e16db70c6a203b4751c1ed8a106c0dc260e6 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -33,8 +33,14 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( switch (op->opcode()) { case HloOpcode::kTanh: { PrimitiveType element_type = op->shape().element_type(); + bool cast_result_to_fp16 = false; string function_name; switch (element_type) { + case F16: + cast_result_to_fp16 = true; + operand_value = ir_builder_->CreateFPCast(operand_value, + ir_builder_->getFloatTy()); + TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; break; @@ -44,7 +50,7 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( default: return Unimplemented("tanh"); } - // Create function declaration for 'tanhf'. + // Create a function declaration. llvm::Function* function = llvm::cast(module_->getOrInsertFunction( llvm_ir::AsStringRef(function_name), operand_value->getType(), @@ -52,8 +58,12 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); - // Create instruction to call 'tanhf'. - return ir_builder_->CreateCall(function, operand_value); + // Create an instruction to call the function. + llvm::Value* result = ir_builder_->CreateCall(function, operand_value); + if (cast_result_to_fp16) { + result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + } + return result; } default: return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); @@ -63,7 +73,13 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( StatusOr CpuElementalIrEmitter::EmitAtan2( PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { string function_name; + bool cast_result_to_fp16 = false; switch (prim_type) { + case F16: + cast_result_to_fp16 = true; + lhs = ir_builder_->CreateFPCast(lhs, ir_builder_->getFloatTy()); + rhs = ir_builder_->CreateFPCast(rhs, ir_builder_->getFloatTy()); + TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; break; @@ -73,7 +89,7 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( default: return Unimplemented("atan2"); } - // Create function declaration for 'atan2'. + // Create a function declaration. llvm::Function* function = llvm::cast(module_->getOrInsertFunction( llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), @@ -81,8 +97,12 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); - // Create instruction to call 'atan2'. - return ir_builder_->CreateCall(function, {lhs, rhs}); + // Create an instruction to call the function. + llvm::Value* result = ir_builder_->CreateCall(function, {lhs, rhs}); + if (cast_result_to_fp16) { + result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + } + return result; } llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index 9c00d476b1fca6c3174af4ebb62dbbde324fd0ea..8008a56df4dbf16e7b57aee8a344058bb0d5883d 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ #include @@ -62,4 +62,4 @@ class ExternalConstantPool { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index cfdf9f4ebc5a5ae2b0188c86edcdc70e3a596971..71e81331897a8bb82438dd5160d2964cb88fd31f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -62,6 +62,7 @@ limitations under the License. #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -72,6 +73,7 @@ namespace { using llvm_ir::AsStringRef; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; +namespace gtl = tensorflow::gtl; } // namespace namespace cpu { @@ -491,7 +493,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { } Status IrEmitter::HandleMap(HloInstruction* map) { - tensorflow::gtl::ArraySlice operands(map->operands()); + gtl::ArraySlice operands(map->operands()); HloComputation* function = map->to_apply(); // The called computation should have been emitted previously. llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function); @@ -1225,205 +1227,6 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( return index_with_free_var; } -Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { - // The output of BatchNormTraining is a tuple of three element: - // - An N-dimensional array containing normalized values. - // - A 1 dimensional array containing the mean value for each feature. - // - A 1 dimensional array containing the variance value for each feature. - HloInstruction* operand = batch_norm_training->operands()[0]; - HloInstruction* scale = batch_norm_training->operands()[1]; - HloInstruction* offset = batch_norm_training->operands()[2]; - float epsilon = batch_norm_training->epsilon(); - int64 feature_index = batch_norm_training->feature_index(); - TF_RET_CHECK(ShapeUtil::IsTuple(batch_norm_training->shape()) && - ShapeUtil::TupleElementCount(batch_norm_training->shape()) == 3); - - const Shape& output_shape = - ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 0); - const Shape& feature_shape = - ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 1); - - // Reduce vector of the non-feature dimensions. - std::vector dimensions_to_reduce; - - for (int64 i = 0; i < operand->shape().dimensions_size(); ++i) { - if (i != feature_index) { - dimensions_to_reduce.push_back(i); - } - } - - // Get the second and third allocations in the output tuple, which should be - // used to store the result of mean and variance value calculation. - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice_mean, - assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{1})); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice_var, - assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{2})); - const int feature_count = output_shape.dimensions(feature_index); - const int size_in_elements = ShapeUtil::ElementsIn(output_shape); - TF_RET_CHECK(ShapeUtil::ElementsIn(operand->shape()) == size_in_elements); - const int elements_per_feature = size_in_elements / feature_count; - - llvm::Value* mean = EmitTempBufferPointer(slice_mean, feature_shape); - llvm_ir::IrArray mean_array(mean, feature_shape); - - llvm::Value* var = EmitTempBufferPointer(slice_var, feature_shape); - llvm_ir::IrArray var_array(var, feature_shape); - - // This loop calculates mean and variance for each feature. - // - // In theory this could be swapped by multi-output fusion. We will evaluate - // this when it's ready. - // - // For variance calculation, we use a simplified formula so we can fuse the - // computation into the same loop to calculate mean: Var=E(X^2) - E(X)^2. - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter( - [&](const llvm_ir::IrArray::Index& index) { - PrimitiveType element_type = operand->shape().element_type(); - // Used to calculate E(X). - llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - "sum_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(element_type)); - - // Used to calculate E(X^2). - llvm::Value* sum_square_address = - llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - "sum_square_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(element_type)); - - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), - sum_address); - - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), - sum_square_address); - - llvm_ir::ForLoopNest loops(IrName(batch_norm_training, "inner"), - &ir_builder_); - - const llvm_ir::IrArray::Index reduced_dims_index = - loops.AddLoopsForShapeOnDimensions( - operand->shape(), dimensions_to_reduce, "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); - - llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); - llvm_ir::IrArray::Index input_index = - FillReducedDimensionIndex(reduced_dims_index, index); - llvm::Value* new_value = - operand_array.EmitReadArrayElement(input_index, &ir_builder_); - - llvm::Value* new_value_square = - ir_builder_.CreateFMul(new_value, new_value); - - llvm::Value* current_sum = ir_builder_.CreateLoad(sum_address); - llvm::Value* current_sum_square = - ir_builder_.CreateLoad(sum_square_address); - // Update sum. - ir_builder_.CreateStore( - ir_builder_.CreateFAdd(current_sum, new_value), sum_address); - - // Update sum square. - ir_builder_.CreateStore( - ir_builder_.CreateFAdd(current_sum_square, new_value_square), - sum_square_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); - - llvm::Value* sum = ir_builder_.CreateLoad(sum_address); - llvm::Value* elements_per_feature_value = llvm::ConstantFP::get( - ir_builder_.getFloatTy(), elements_per_feature); - llvm::Value* mean = - ir_builder_.CreateFDiv(sum, elements_per_feature_value); - llvm::Value* mean_square = ir_builder_.CreateFMul(mean, mean); - llvm::Value* sum_square = - ir_builder_.CreateLoad(sum_square_address); - - // Var=E(X^2) - E(X)^2. - llvm::Value* var = ir_builder_.CreateFSub( - ir_builder_.CreateFDiv(sum_square, elements_per_feature_value), - mean_square); - - var_array.EmitWriteArrayElement(index, var, &ir_builder_); - return mean; - }, - mean_array, &ir_builder_) - .EmitLoop(IrName(batch_norm_training, "mean_var"))); - - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training)); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); - - llvm::Value* normalized = EmitTempBufferPointer(slice, output_shape); - - llvm_ir::IrArray target_array(normalized, output_shape); - - AddAliasingInformationToIrArray(*batch_norm_training, &target_array); - - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter( - [this, mean_array, var_array, epsilon, operand, dimensions_to_reduce, - feature_index, offset, scale](const llvm_ir::IrArray::Index& index) { - // The following logic normalizes the input value, scales and shifts - // it: - // - // normalized = (input - mean) / sqrt(variance + epsilon) - // result = normalized * scale + offset - - // Current index in the feature dimension. - llvm_ir::IrArray::Index feature_index_value(1, - index[feature_index]); - - llvm::Value* mean = mean_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - llvm::Value* var = var_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - - llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); - llvm::Value* input = - operand_array.EmitReadArrayElement(index, &ir_builder_); - - llvm::Value* variance_with_epsilon = ir_builder_.CreateFAdd( - var, llvm::ConstantFP::get(ir_builder_.getFloatTy(), epsilon)); - llvm::Function* func_llvm_sqrt = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {ir_builder_.getFloatTy()}); - llvm::Value* variance_sqrt = - ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); - llvm::Value* normalized = ir_builder_.CreateFDiv( - ir_builder_.CreateFSub(input, mean), variance_sqrt); - llvm_ir::IrArray offset_array(GetIrArrayFor(offset)); - llvm::Value* offset = offset_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - llvm_ir::IrArray scale_array(GetIrArrayFor(scale)); - llvm::Value* scale = scale_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - llvm::Value* result = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(normalized, scale), offset); - - return result; - }, - target_array, &ir_builder_) - .EmitLoop(IrName(batch_norm_training, "normalize"))); - - llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training), - {normalized, mean, var}, &ir_builder_, module_); - return Status::OK(); -} - -Status IrEmitter::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { - // TODO(b/62843645) Implement BatchNormGrad on CPU backend. - return Unimplemented( - "BatchNormGrad is not implemented on CPU. See b/62843645."); -} - Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); auto param_number = parameter->parameter_number(); @@ -1469,6 +1272,52 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } +// Returns true if the relative order of the unreduced dimensions stays the same +// through the reduce operation. +static bool ReductionPreservesLayout(const HloInstruction& reduce) { + DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce); + + // Maps dimensions that were not reduced from their dimension numbers in the + // source shape to their dimensions numbers in the destination shape. + // + // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains + // [0->0, 3->1]. + gtl::FlatMap unreduced_dim_map; + + gtl::FlatSet reduced_dims(reduce.dimensions().begin(), + reduce.dimensions().end()); + + const Shape& operand_shape = reduce.operand(0)->shape(); + const Shape& result_shape = reduce.shape(); + + int64 delta = 0; + for (int64 i = 0; i < operand_shape.dimensions_size(); i++) { + if (reduced_dims.count(i)) { + delta++; + } else { + InsertOrDie(&unreduced_dim_map, i, i - delta); + } + } + + // Iterate dimensions minor to major and check that the corresponding + // dimensions in the source and target shapes are equivalent. + int64 result_dim_idx = 0; + for (int64 operand_dim_idx = 0; + operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { + int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); + if (!reduced_dims.count(operand_dim)) { + if (FindOrDie(unreduced_dim_map, operand_dim) != + result_shape.layout().minor_to_major(result_dim_idx++)) { + return false; + } + } + } + + CHECK_EQ(result_dim_idx, result_shape.dimensions_size()); + + return true; +} + IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( HloComputation* function, string* failure_reason) const { CHECK_EQ(function->num_parameters(), 2); @@ -1632,7 +1481,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + HloInstruction* arg, gtl::ArraySlice dimensions, unsigned element_alignment) { ShardedVector accumulator; accumulator.reserve(accumulator_type.size()); @@ -1736,8 +1585,12 @@ void IrEmitter::EmitShardedVectorStore( StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function, + gtl::ArraySlice dimensions, HloComputation* function, string* failure_reason) { + if (!ReductionPreservesLayout(*reduce)) { + return false; + } + ReductionGenerator reduction_generator = MatchReductionGenerator(function, failure_reason); if (!reduction_generator) { @@ -1881,7 +1734,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( Status IrEmitter::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + gtl::ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (!options::VectorizedReduceDisabled(hlo_module_config_)) { string vectorization_failure_reason; @@ -2001,7 +1854,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // // * Implement the memcpy within the innermost loop. - tensorflow::gtl::FlatSet inner_dims; + gtl::FlatSet inner_dims; for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; @@ -2329,8 +2182,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { - tensorflow::gtl::ArraySlice operands( - custom_call->operands()); + gtl::ArraySlice operands(custom_call->operands()); tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = @@ -2461,8 +2313,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { } StatusOr IrEmitter::EmitFastConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands, + HloInstruction* concatenate, gtl::ArraySlice operands, string* failure_reason) { if (ShouldEmitParallelLoopFor(*concatenate)) { *failure_reason = @@ -2601,8 +2452,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + gtl::ArraySlice operands(concatenate->operands()); string failure_reason; TF_ASSIGN_OR_RETURN( bool successful, @@ -2915,7 +2765,7 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( // for a single element_type value, and loads it after call. llvm::Value* IrEmitter::EmitElementFunctionCall( llvm::Function* function, const Shape& return_shape, - tensorflow::gtl::ArraySlice parameter_addresses, + gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name) { llvm::Value* return_value_buffer = EmitArrayFunctionCall( function, return_shape, 1, parameter_addresses, name); @@ -2935,8 +2785,7 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( // temps) // return return_value_buffer -- address of the return value. void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, - tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Function* function, gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { ir_builder_.CreateCall( function, GetArrayFunctionCallArguments( @@ -2949,7 +2798,7 @@ void IrEmitter::EmitArrayFunctionCallInto( llvm::Value* IrEmitter::EmitArrayFunctionCall( llvm::Function* function, const Shape& return_shape, int64 element_count, - tensorflow::gtl::ArraySlice parameter_addresses, + gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name) { llvm::Value* elements = llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count); @@ -3059,8 +2908,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, Status IrEmitter::ElementTypesSameAndSupported( const HloInstruction& instruction, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice supported_types) { + gtl::ArraySlice operands, + gtl::ArraySlice supported_types) { for (auto operand : operands) { TF_RET_CHECK( ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 66f2aeeab33dbaa34297c8dc6a37c3ad481820d8..509440251497cd7337284c39dae05c5f6c28e7c2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -125,8 +125,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 1fd2da4dce23982ed030f3aa8ec604182d0ebab8..557aa4a6bfc2ef70cafca4b226f8d8f15ea01e2b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -131,4 +131,4 @@ Status EmitCallToParallelForkJoin( } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h index 2d29550fd5bd659770cc6300e56b57bf1763e671..f8963841158b71a30aa926e3b2b153c42bf78eb1 100644 --- a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ #include @@ -53,4 +53,4 @@ class Registrar { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index d1b88b27f068962fb86477fcad3e4390b1636c2b..cd997f07890cdc1d9a546ede58cc1d992b6416ae 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -61,9 +61,9 @@ ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr> function_names, std::unordered_map> aligned_constants, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), jit_(std::move(jit)), assignment_(std::move(assignment)), diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index 90ac94ef9288b2e860cb30c47ed44a7b96e4825d..c393e9b8ea39bfb4c605ebba8e2cd29726bc4af9 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -55,7 +55,7 @@ class ParallelCpuExecutable : public Executable { std::unordered_map> aligned_constants, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); ~ParallelCpuExecutable() override {} diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 9335d2818e99eb3588537d80dabddda08c1c020e..ce92e36a944de33b991d97460f0b2e859ad56081 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -70,4 +70,4 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 5801ec8d270cdaed7f2f65c24987a9ea643edb02..7140dabe516cd7ea9260456e994e8b63b68c60d6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -99,4 +99,4 @@ class ParallelTaskAssigner : public HloPassInterface { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h index fcf1cc62078d3847435a2e75e3ca9d109cf8b200..1cf0ec6e3df400e35fa4e755a0b25b4ce7966e8f 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ #include "tensorflow/core/platform/types.h" @@ -30,4 +30,4 @@ extern void __xla_cpu_runtime_ParallelForkJoin( } // extern "C" -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h index cb7e0a81f09e2702de565012e1fcac8b7cd841ab..1bd8dfb377acc1f7cfbe9a92773f87f0ef25de3a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ #include "tensorflow/core/platform/types.h" @@ -42,4 +42,4 @@ void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m, } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h index 7a2d00421cfdc8e41ec48698a16665621de16bda..33d02b70e61e3311c9af934e80874939fbe3adae 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.h +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ #include @@ -102,4 +102,4 @@ class ShapePartitionIterator { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 5403bf48b748c587802c6ed7abb4699e8395ca67..de5e9b411905a37a7db7d05f51cca2802c1526ed 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -47,7 +47,7 @@ namespace cpu { namespace { // A simple SymbolResolver that delegates to the host dynamic linker. -class SimpleResolver : public llvm::JITSymbolResolver { +class SimpleResolver : public llvm::LegacyJITSymbolResolver { public: explicit SimpleResolver(ExternalConstantPool* external_constant_pool) : external_constant_pool_(external_constant_pool) {} diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index 5ff0ab34eac0cd0fbc264b408c57653c944402a6..1959b687f16d6909a3283021c8635b3e65e6e412 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -41,4 +41,4 @@ class DotDecomposer : public HloPassInterface { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 21e7fbea291721dfc446bae2a7002a8ec2520be4..90481c7a88f90edea5399ee44aee2d2c77fc115f 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -73,7 +73,7 @@ StatusOr> Executable::ExecuteOnStreamWrapper( std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? MakeUnique(&hlo_profile_printer(), + ? MakeUnique(&hlo_profile_printer_data(), &hlo_profile_index_map()) : nullptr; diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 5ecfdffe211c571b1bb2bc30ff2acd3021c735ae..0aee535ee780ef000bc5e9963ff48786b3a61eb2 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -44,13 +44,14 @@ namespace xla { // interface that is used for launching compiled programs across platforms. class Executable { public: - explicit Executable(std::unique_ptr hlo_module, - std::unique_ptr hlo_profile_printer, - std::unique_ptr hlo_profile_index_map) + explicit Executable( + std::unique_ptr hlo_module, + std::unique_ptr hlo_profile_printer_data, + std::unique_ptr hlo_profile_index_map) : hlo_module_(std::move(hlo_module)), - hlo_profile_printer_(std::move(hlo_profile_printer)), + hlo_profile_printer_data_(std::move(hlo_profile_printer_data)), hlo_profile_index_map_(std::move(hlo_profile_index_map)) { - CHECK_EQ(hlo_profile_printer_.get() == nullptr, + CHECK_EQ(hlo_profile_printer_data_.get() == nullptr, hlo_profile_index_map_.get() == nullptr); } virtual ~Executable() {} @@ -116,9 +117,9 @@ class Executable { "Equality test on this executable is not implemented."); } - const HloProfilePrinter& hlo_profile_printer() const { + const HloProfilePrinterData& hlo_profile_printer_data() const { CHECK(hlo_profiling_enabled()); - return *hlo_profile_printer_; + return *hlo_profile_printer_data_; } const HloProfileIndexMap& hlo_profile_index_map() const { @@ -129,7 +130,9 @@ class Executable { // Returns whether this executable was compiled with HLO profilings support // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. - bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } + bool hlo_profiling_enabled() const { + return hlo_profile_printer_data_ != nullptr; + } const HloModule& module() const { return *hlo_module_; } @@ -179,7 +182,7 @@ class Executable { // execution. int64 execution_count_ = 0; - std::unique_ptr hlo_profile_printer_; + std::unique_ptr hlo_profile_printer_data_; std::unique_ptr hlo_profile_index_map_; }; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d7ca0f6846834ae77569930325d3fc6b9fd5cca8..df5e2e35f802b476f4d9fef2cd4816089663686f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -475,6 +475,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 4b511cb4bb94addfae53d6b2e6d6f86d5b9afd84..5af7a77ea858563fbea05af8efd54f96a74aee93 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -72,9 +72,27 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( tensorflow::gtl::ArraySlice input_types, PrimitiveType output_type) const { // The libdevice math functions differentiate between "double" and "float" by - // appending an 'f' to the function's name. + // appending an 'f' to the function's name. libdevice doesn't have f16 math + // functions, so we convert the operands to f32 before calling the function + // and then convert the result back to f16. string munged_callee = callee_name; + bool cast_result_to_fp16 = false; + std::vector converted_operands(operands.begin(), + operands.end()); + std::vector converted_input_types(input_types.begin(), + input_types.end()); switch (output_type) { + case F16: + cast_result_to_fp16 = true; + for (int64 i = 0; i < operands.size(); ++i) { + if (input_types[i] == F16) { + converted_operands[i] = ir_builder_->CreateFPCast( + converted_operands[i], ir_builder_->getFloatTy()); + converted_input_types[i] = F32; + } + } + output_type = F32; + TF_FALLTHROUGH_INTENDED; case F32: StrAppend(&munged_callee, "f"); break; @@ -84,7 +102,13 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( return Unimplemented("Bad type for libdevice math call: %s", PrimitiveType_Name(output_type).c_str()); } - return EmitMathCall(munged_callee, operands, input_types, output_type); + llvm::Value* result = EmitMathCall(munged_callee, converted_operands, + converted_input_types, output_type) + .ValueOrDie(); + if (cast_result_to_fp16) { + result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + } + return result; } StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( @@ -92,10 +116,13 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, PrimitiveType output_type) const { - // llvm intrinsics differentiate between float/double functions via the ".f32" - // and ".f64" suffixes. + // llvm intrinsics differentiate between half/float/double functions via + // the suffixes ".f16", ".f32" and ".f64". string munged_callee = callee_name; switch (output_type) { + case F16: + StrAppend(&munged_callee, ".f16"); + break; case F32: StrAppend(&munged_callee, ".f32"); break; @@ -233,12 +260,6 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kFloor: - return EmitLibdeviceMathCall("__nv_floor", {operand_value}, {input_type}, - output_type); - case HloOpcode::kCeil: - return EmitLibdeviceMathCall("__nv_ceil", {operand_value}, {input_type}, - output_type); case HloOpcode::kTanh: return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 525a2af941e77a27c0e01543e00e8a4c3e4b9f62..832494d17e9c4e1d9e92e18ef331df1cf3689024 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ #include @@ -49,4 +49,4 @@ class ForThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index bd720f8584f6254c43a3e2a1a5399aa919eebbc0..4c523a66de977cd32423b25f0d165c4f4ba51c4a 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -44,4 +44,4 @@ class FusionMerger : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 89acac2c3ff77a93b6cf3b871a130dcd7edecf30..0cca3ca0926ad1f9fe21803a771d66ac8b1affaf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -58,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -137,6 +138,10 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); + // Convert BF16 operations to F32 operations so that the GPU backend can + // support BF16 operations without directly implementing a BF16 lowering for + // most ops. + pipeline.AddPass(BF16, F32); pipeline.AddPass(); { auto& pass = @@ -281,14 +286,16 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { return; } - // ptxas 9.0 before 9.0.276 miscompiles some address calculations with large - // offsets (e.g. "load ptr + large_constant"), b/70245379. - if (vmaj == 9 && vmin == 0 && vdot < 276) { + // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some + // address calculations with large offsets (e.g. "load ptr + large_constant"), + // b/70245379. + if ((vmaj == 9 && vmin == 0 && vdot < 276) || + (vmaj == 9 && vmin == 1 && vdot < 121)) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot - << ", which is in range [9.0.0, 9.0.276). These versions are " - "known to miscompile XLA code, leading to incorrect " - "results or invalid-address errors."; + << ", which is in range [9.0.0, 9.0.276) + [9.1.0, 9.1.121). " + "These versions are known to miscompile XLA code, leading " + "to incorrect results or invalid-address errors."; } } @@ -309,16 +316,24 @@ void WarnIfBadDriverJITVersion() { } se::cuda::DriverVersion version = version_or_status.ValueOrDie(); - // The driver JIT in 384 before 384.108 miscompiles some address + // The following versions of the driver JIT miscompile some address // calculations with large offsets (e.g. "load ptr + large_constant"), - // b/70245379. - if (std::get<0>(version) == 384 && std::get<1>(version) < 108) { + // b/70245379: + // + // - 384.x before 384.108 + // - 387.x before 387.40 + // - 390.x before 390.10. + auto vmaj = std::get<0>(version); + auto vmin = std::get<1>(version); + if ((vmaj == 384 && vmin < 108) || // + (vmaj == 387 && vmin < 40) || // + (vmaj == 390 && vmin < 10)) { LOG(WARNING) << "*** WARNING *** Invoking the PTX->SASS JIT from driver version " << se::cuda::DriverVersionToString(version) - << ", which is in range [384.0.0, 384.108.0). These versions are " - "known to miscompile XLA code, leading to incorrect results or " - "invalid-address errors."; + << ", which is in range [384.0.0, 384.108.0) + [387.0.0, 387.40.0) + " + "[390.0.0, 390.10.0). These versions are known to miscompile XLA " + "code, leading to incorrect results or invalid-address errors."; } }); } @@ -578,14 +593,14 @@ StatusOr> GpuCompiler::RunBackend( XLA_VLOG_LINES(2, thunk_schedule->ToString()); std::unique_ptr profile_index_map; - std::unique_ptr profile_printer; + std::unique_ptr profile_printer; if (module->config().hlo_profiling_enabled()) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = MakeUnique(*module); profile_printer = - CreateHloProfilePrinter(*profile_index_map, cost_analysis); + CreateHloProfilePrinterData(*profile_index_map, cost_analysis); } auto* gpu_executable = new GpuExecutable( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h index 572c85628278752f924b90dbb7134c5fc8fb9740..eb1ca4c6c95a23d2a08f5f9c3cbc85e7d47d4f89 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ #include "tensorflow/compiler/xla/types.h" @@ -28,4 +28,4 @@ extern const int64 kCudaMallocAlignBytes; } // namespace gpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 51d164cdf427f9513bc340e090832a9b064b999c..f5d67b9ea9498df3f023ea9a694a63b468c5be18 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -116,9 +116,9 @@ GpuExecutable::GpuExecutable( std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), ptx_(ptx), cubin_(cubin), diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 00da64dfade8ddb0694c0ee7ac158c9f2e15a508..b19cfd43debd0a5490495d176fa2f1fcd625da07 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,7 @@ class GpuExecutable : public Executable { std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); // This should be called after set_ir_module_string. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index d9550f81b591ead3f6e8d3de4f62896ee04d2f82..d63e213d2b1efab4bcff75541cc5ab33d7a07976 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -39,4 +39,4 @@ class GpuHloSupportChecker : public HloPassInterface { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index c29fee0879c02021fdc23ac0e02ab398cf40f99e..2923a79af0a559b08a2126162130a83801d024f8 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -28,7 +28,7 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { CHECK_EQ(HloOpcode::kConvolution, conv.opcode()); - return window_util::HasEvenPadding(conv.window()) && + return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); } @@ -43,7 +43,7 @@ HloInstruction* MaybePaddedAndSlicedInput( const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* input) { HloComputation* computation = input->parent(); - if (!window_util::HasEvenPadding(conv_window) || + if (!window_util::HasSymmetricPadding(conv_window) || window_util::HasBaseDilation(conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. @@ -190,7 +190,7 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { - if (window_util::HasEvenPadding(backward_conv->window())) { + if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; } @@ -285,7 +285,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* backward_conv) { - if (window_util::HasEvenPadding(backward_conv->window())) { + if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h index a4f527fce0e4e280e24efc1f33ea68a0b71555b9..fe3a954e1828ee4a323872eea81f64c7e780ad24 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -40,4 +40,4 @@ StatusOr> CanTransformWhileToFor( } // namespace gpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 1773bb401d380031f6c860d295e76d2f62c9e5ff..c782d1b0add17c70e0f54826917df251d5a613e2 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -54,45 +54,96 @@ bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { return false; } +// Finds out the Tuple Shape of the new instruction after converting the element +// type of the operands of the original instruction from `from_type` to +// `to_type`. +// +// This routine assumes the resulting `shape` of the original instruction is a +// non-nested tuple. This assumption is currently safe as only kTuple, kInfeed, +// kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce +// results with tuple shapes, and this routine is only called to convert the +// result shapes of kBatchNorm* HLO instructions, which are non-nested tuples. +Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type, + PrimitiveType to_type) { + std::vector new_tuple_subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + Shape subshape = ShapeUtil::GetTupleElementShape(shape, i); + CHECK(!ShapeUtil::IsTuple(subshape)); + if (subshape.element_type() == from_type) { + subshape = ShapeUtil::ChangeElementType(subshape, to_type); + } + new_tuple_subshapes.push_back(subshape); + } + return ShapeUtil::MakeTupleShape(new_tuple_subshapes); +} + +// Converts the elements of the result of `hlo` to produce a new tuple with +// shape `to_shape`. +// +// This routine assumes `hlo` is an instruction that produces a non-nested Tuple +// as a result. +HloInstruction* ConvertTupleElements(HloInstruction* hlo, + const Shape& to_shape) { + const Shape& shape = hlo->shape(); + HloComputation* computation = hlo->parent(); + std::vector tuple_elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i); + HloInstruction* element = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(ele_shape, hlo, i)); + const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i); + CHECK(!ShapeUtil::IsTuple(ele_shape)); + if (ele_shape.element_type() != to_ele_shape.element_type()) { + element = computation->AddInstruction( + HloInstruction::CreateConvert(to_ele_shape, element)); + } + tuple_elements.push_back(element); + } + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + } // namespace HloElementTypeConverter::HloElementTypeConverter( PrimitiveType eliminate_type, PrimitiveType replace_with_type) : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} +// This routine converts the arithmetic operations in the given module that use +// eliminate_type_ to operations that use replace_with_type_. StatusOr HloElementTypeConverter::Run(HloModule* module) { XLA_VLOG_LINES( 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); + + if (eliminate_type_ == replace_with_type_) { + return false; + } + bool changed = false; for (auto* computation : module->computations()) { for (auto* hlo : computation->MakeInstructionPostOrder()) { + const auto opcode = hlo->opcode(); // These are ops where it does not make sense to convert them. - if (hlo->opcode() == HloOpcode::kParameter || - hlo->opcode() == HloOpcode::kConstant || - hlo->opcode() == HloOpcode::kTuple || - hlo->opcode() == HloOpcode::kConvert || - hlo->opcode() == HloOpcode::kGetTupleElement || - hlo->opcode() == HloOpcode::kInfeed || - hlo->opcode() == HloOpcode::kOutfeed) { + if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || + opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kGetTupleElement || + opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { continue; } // We cannot change a CustomCall since we have no way of adjusting the // called binary to expect the updated type. - if (hlo->opcode() == HloOpcode::kCustomCall) { + if (opcode == HloOpcode::kCustomCall) { continue; } // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. - if (hlo->opcode() == HloOpcode::kWhile || - hlo->opcode() == HloOpcode::kCall || - hlo->opcode() == HloOpcode::kFusion || - hlo->opcode() == HloOpcode::kMap || - hlo->opcode() == HloOpcode::kReduce || - hlo->opcode() == HloOpcode::kReduceWindow || - hlo->opcode() == HloOpcode::kSelectAndScatter || - hlo->opcode() == HloOpcode::kConditional) { + if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || + opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || + opcode == HloOpcode::kSelectAndScatter || + opcode == HloOpcode::kConditional) { continue; } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); @@ -106,6 +157,11 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { continue; } + // Handle instructions that perform arithmetic operations and contain + // operands with eliminate_type_. + // + // First, convert the operands with eliminate_type_ to operands with + // replace_with_type_. std::vector new_operands; for (HloInstruction* operand : hlo->operands()) { if (operand->shape().element_type() == eliminate_type_) { @@ -114,6 +170,10 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { new_operands.push_back(operand); } + // Then find out the result type of the new instruction with the same + // opcode but using the converted operands, create the new instruction, + // and convert the result of the new instruction back to match the result + // type of the original instruction. HloInstruction* new_hlo; if (hlo->shape().element_type() == eliminate_type_) { Shape shape = @@ -121,10 +181,20 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { new_hlo = computation->AddInstruction( hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); new_hlo = ToElementType(new_hlo, eliminate_type_); + } else if (ShapeUtil::IsTuple(hlo->shape())) { + Shape old_shape = hlo->shape(); + Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, + replace_with_type_); + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + new_shape, new_operands, hlo->GetModule())); + // Convert the elements of the result of `new_hlo` to produce a new + // tuple with shape `old_shape`. + new_hlo = ConvertTupleElements(new_hlo, old_shape); } else { new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( hlo->shape(), new_operands, hlo->GetModule())); } + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); changed = true; } diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb94d9f19b825d1321263a4737b66a6bf198a772 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HloElementTypeConverterTest : public HloTestBase { + public: + std::unique_ptr CreateModuleFromHloString( + const string& hlo_string) { + return HloRunner::CreateModuleFromString(hlo_string, + GetDebugOptionsForTest()) + .ValueOrDie(); + } +}; + +TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { + const string& hlo_string = R"( + HloModule custom_call + ENTRY CustomCall { + constant = bf16[1]{0} constant({12345}) + ROOT custom-call = bf16[1,2,3]{0,2,1} custom-call(constant), + custom_call_target="foo" + } + )"; + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_FALSE(converted); +} + +TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { + const string& hlo_string = R"( + HloModule InfeedOutfeed + ENTRY RoundTrip16MiBR1.v2 { + ROOT infeed = bf16[4]{0} infeed() + outfeed = () outfeed(infeed) + } + )"; + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_FALSE(converted); +} + +TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) { + const string& hlo_string = R"( + HloModule NestedTuples + ENTRY NestedTuples.v5 { + constant.4 = bf16[] constant(42) + constant.2 = f32[2]{0} constant({1, 2}) + constant.3 = bf16[] constant(42) + add = bf16[] add(constant.2, constant.3) + tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add) + constant.5 = bf16[2]{0} constant({22, 44}) + ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5) + } + )"; + + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_TRUE(converted); + const HloInstruction* bf16_op = + module->entry_computation()->root_instruction()->operand(0)->operand(1); + EXPECT_THAT(bf16_op, op::Convert(op::Add(op::Constant(), op::Convert()))); +} + +TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { + const string& hlo_string = R"( + HloModule BatchNormGrad + ENTRY BatchNormGrad.v6 { + constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0}, + {0} }, { /*i1=1*/ {0}, {0} } } }) + constant.5 = bf16[2]{0} constant({1, 1}) + constant.6 = bf16[2]{0} constant({0, 0}) + constant.7 = bf16[2]{0} constant({1, 1}) + constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/ + {5}, {6} }, { /*i1=1*/ {7}, {8} } } }) + ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0}) + batch-norm-grad(constant.4, constant.5, constant.6, constant.7, + constant.8), epsilon=0, feature_index=2 + } + )"; + + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_TRUE(converted); + const HloInstruction* tuple_instr = + module->entry_computation()->root_instruction(); + ::testing::Matcher batch_norm = + op::BatchNormGrad(); + EXPECT_THAT(tuple_instr, + op::Tuple(op::Convert(op::GetTupleElement(batch_norm, 0)), + op::Convert(op::GetTupleElement(batch_norm, 1)), + op::Convert(op::GetTupleElement(batch_norm, 2)))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3a846a752988efd618a1d6b9ed3c9e7a27627eee..e3f5c17e35f5294e204993af9396dec326a779cd 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -166,6 +167,34 @@ StatusOr> ElementWiseUnaryOpImpl( return std::move(result); } +// For one particular placement of a window in a base shape (the placement is +// represented as `window_count_index`), iterates inside the window. Translates +// the window index into base index. If the base index is within bound, call `f` +// with the base index. +void IterateThroughWindow( + const Shape& window_shape, const Window& window, const Shape& base_shape, + const tensorflow::gtl::ArraySlice& window_count_index, + const std::function&)>& f) { + const int64 rank = ShapeUtil::Rank(base_shape); + DimensionVector window_index(rank); + std::fill(window_index.begin(), window_index.end(), 0); + do { + std::vector base_index(rank); + bool out_of_bound = false; + for (int64 i = 0; i < rank; ++i) { + base_index[i] = window_count_index[i] * window.dimensions(i).stride() + + window_index[i] - window.dimensions(i).padding_low(); + if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { + out_of_bound = true; + break; + } + } + if (!out_of_bound) { + f(base_index); + } + } while (IndexUtil::BumpIndices(window_shape, &window_index)); +} + } // namespace template @@ -945,14 +974,21 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { out_index[output_spatial_dim] * window_dim.stride() - window_dim.padding_low() + rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. - if (undilated_index % window_dim.base_dilation() != 0) { + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { goto cnt; } - // Calculate the actual lhs (input) index after dilation. - lhs_index[input_spatial_dim] = - undilated_index / window_dim.base_dilation(); + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + if (window_dim.base_dilation() > 1) { + lhs_index[input_spatial_dim] = + undilated_index / window_dim.base_dilation(); + } else { + lhs_index[input_spatial_dim] = undilated_index; + } // Skip if input index is not in bound. if (!(lhs_index[input_spatial_dim] >= 0 && @@ -1413,6 +1449,111 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + auto operand = select_and_scatter->operand(0); + auto source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + auto result = Literal::CreateFromShape(select_and_scatter->shape()); + + // Initialize result array with the init value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + return init_scalar; + })); + + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + HloComputation* select = select_and_scatter->select(); + HloComputation* scatter = select_and_scatter->scatter(); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); + + int64 rank = ShapeUtil::Rank(operand_literal.shape()); + + HloEvaluator embedded_evaluator; + DimensionVector source_index(rank); + + std::fill(source_index.begin(), source_index.end(), 0); + do { + // For each element in `source`, we place a window in `operand`. For each + // window placement, we iterate inside the window twice: + // + // 1. Find the selected index by applying `select` function to all + // elements. E.g., If the `select` function is GreaterEqual, the first + // iteration through the window finds the biggest value and returns its + // index. + // + // 2. Using the selected index, scatter value from `source` to result. We + // do this by iterating through the window, and compare each index with + // the selected index. + tensorflow::gtl::optional selected_val; + tensorflow::gtl::optional> selected_index; + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + if (!selected_val) { + selected_val = curr_val; + selected_index = operand_index; + } + const auto curr_val_literal = Literal::CreateR0(curr_val); + const auto selected_val_literal = + Literal::CreateR0(*selected_val); + + const std::vector args = { + curr_val_literal.get(), selected_val_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*select, args) + .ConsumeValueOrDie(); + bool selected = computed_result->Get({}); + if (selected) { + selected_val = curr_val; + selected_index = operand_index; + } + embedded_evaluator.ResetVisitStates(); + }); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + if (std::equal(operand_index.begin(), operand_index.end(), + selected_index->begin())) { + auto source = source_literal.Get(source_index); + auto scattered = result->Get(operand_index); + const auto source_literal = Literal::CreateR0(source); + const auto scattered_literal = + Literal::CreateR0(scattered); + + const std::vector args = { + source_literal.get(), scattered_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*scatter, args) + .ConsumeValueOrDie(); + result->Set(operand_index, computed_result->Get({})); + // Clear visit states so that the we can use the evaluator again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + } + }); + } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + + parent_->evaluated_[select_and_scatter] = std::move(result); + return Status::OK(); + } + Status HandleReduceWindow(HloInstruction* reduce_window) override { auto operand = reduce_window->operand(0); const Window& window = reduce_window->window(); @@ -1461,39 +1602,28 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::fill(window_index.begin(), window_index.end(), 0); std::fill(operand_index.begin(), operand_index.end(), 0); - do { - bool out_of_bound = false; - for (int i = 0; i < operand_index.size(); ++i) { - operand_index[i] = - output_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); - if (operand_index[i] < 0 || - operand_index[i] >= operand_literal.shape().dimensions(i)) { - out_of_bound = true; - break; - } - } - if (!out_of_bound) { - auto curr_val = operand_literal.Get(operand_index); - - // Evaluate computation with specified literal operands. - const auto curr_val_literal = - Literal::CreateR0(curr_val); - const auto result_val_literal = - Literal::CreateR0(result_val); - const std::vector args = { - curr_val_literal.get(), result_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - - // Clear visit states so that the we can use the evaluate again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - - result_val = computed_result->Get({}); - } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); + IterateThroughWindow( + window_shape, window, operand_literal.shape(), output_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + + // Evaluate computation with specified literal operands. + const auto curr_val_literal = + Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + const std::vector args = { + curr_val_literal.get(), result_val_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + + // Clear visit states so that the we can use the evaluate again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get({}); + }); return result_val; })); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 02bb8b0a47065c359603a113f49626bf3ad344d8..3b2b697e492a78a06a4e5ae6bf056ff8676f2ff5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ #include @@ -195,4 +195,4 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index 849aac0b12b096e5f7c4a5c441fc019c48a27060..f0df93b61d29c1535d8a89fbd65e669de5b43729 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -40,83 +40,75 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { } } -std::unique_ptr CreateHloProfilePrinter( +std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, const HloCostAnalysis& cost_analysis) { - using HloComputationInfo = HloProfilePrinter::HloComputationInfo; - using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo; - - HloComputationInfo* computation_infos = - new HloComputationInfo[hlo_profile_index_map.computation_count()]; - - // There are two "indices" in play here. The first one is the index of the - // HloComputationInfo or HloInstructionInfo in the array that contains said - // HloComputationInfo or HloInstructionInfo. The second index is the index of - // the HloComputationInfo or HloInstructionInfo in the profile counters array, - // as decided by hlo_profile_index_map. The latter index is always referred - // to as "profile_index". - - size_t computation_index_in_static_data = 0; - size_t max_profile_index = hlo_profile_index_map.total_count(); - for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) { - CHECK_LT(pair.second, max_profile_index); + using HloComputationInfo = HloProfilePrinterData::HloComputationInfo; + using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo; + + size_t profile_counters_size = hlo_profile_index_map.total_count(); + + std::unique_ptr profile_printer_data = + MakeUnique(); + profile_printer_data->set_profile_counters_size(profile_counters_size); + profile_printer_data->mutable_computation_infos()->Reserve( + hlo_profile_index_map.computation_count()); + + const auto& computation_to_profile_idx_map = + hlo_profile_index_map.computation_to_profile_idx(); + + // computation_to_profile_idx_map's order is not deterministic so create a + // deterministic computation_and_profile_idx_list so that we end up with a + // deterministic HloProfilePrinterData protobuf. + + std::vector> + computation_and_profile_idx_list(computation_to_profile_idx_map.begin(), + computation_to_profile_idx_map.end()); + + // The profile indices were computed deterministically in + // HloProfileIndexMap::HloProfileIndexMap. + c_sort(computation_and_profile_idx_list, + [](const std::pair& left, + const std::pair& right) { + return left.second < right.second; + }); + + for (const auto& pair : computation_and_profile_idx_list) { + CHECK_LT(pair.second, profile_counters_size); const HloComputation* computation = pair.first; - size_t current_computation_index = computation_index_in_static_data++; HloComputationInfo* computation_info = - &computation_infos[current_computation_index]; + profile_printer_data->add_computation_infos(); - computation_info->name = strdup(computation->name().c_str()); - computation_info->profile_index = pair.second; - computation_info->instructions = - new HloInstructionInfo[computation->instruction_count()]; - computation_info->instructions_size = computation->instruction_count(); + computation_info->set_name(computation->name()); + computation_info->set_profile_index(pair.second); + computation_info->mutable_instruction_infos()->Reserve( + computation->instruction_count()); - size_t instruction_index_in_static_data = 0; for (const HloInstruction* hlo : computation->instructions()) { - HloProfilePrinter::HloInstructionInfo* instruction_info = - &computation_info->instructions[instruction_index_in_static_data++]; - instruction_info->long_name = strdup(hlo->ToString().c_str()); - instruction_info->short_name = strdup( - hlo->ToString(HloPrintOptions().set_compact_operands(true)).c_str()); - instruction_info->category = strdup(hlo->ToCategory().c_str()); - instruction_info->flop_count = cost_analysis.flop_count(*hlo); - instruction_info->transcendental_count = - cost_analysis.transcendental_count(*hlo); - instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); - instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo); - instruction_info->profile_index = - hlo_profile_index_map.GetProfileIndexFor(*hlo); - CHECK_LT(instruction_info->profile_index, max_profile_index); + HloInstructionInfo* instruction_info = + computation_info->add_instruction_infos(); + instruction_info->set_long_name(hlo->ToString()); + instruction_info->set_short_name( + hlo->ToString(HloPrintOptions().set_compact_operands(true))); + instruction_info->set_category(hlo->ToCategory()); + instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); + instruction_info->set_transcendental_count( + cost_analysis.transcendental_count(*hlo)); + instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo)); + instruction_info->set_optimal_seconds( + cost_analysis.optimal_seconds(*hlo)); + instruction_info->set_profile_index( + hlo_profile_index_map.GetProfileIndexFor(*hlo)); } } - auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos, - int64 computation_infos_size) { - for (int64 i = 0; i < computation_infos_size; i++) { - HloInstructionInfo* instruction_infos = computation_infos[i].instructions; - for (int64 j = 0; j < computation_infos[i].instructions_size; j++) { - // We can't make instruction_infos[j].long_name etc. non-const pointers - // since they may point into static storage, so we have a const_cast - // here. - free(const_cast(instruction_infos[j].long_name)); - free(const_cast(instruction_infos[j].short_name)); - free(const_cast(instruction_infos[j].category)); - } - delete[] instruction_infos; - free(const_cast(computation_infos[i].name)); - } - delete[] computation_infos; - }; - - return MakeUnique( - computation_infos, hlo_profile_index_map.computation_count(), - /*profile_counters_size=*/max_profile_index, deleter); + return profile_printer_data; } HloExecutionProfile::HloExecutionProfile( - const HloProfilePrinter* hlo_profile_printer, + const HloProfilePrinterData* hlo_profile_printer_data, const HloProfileIndexMap* hlo_profile_index_map) - : hlo_profile_printer_(*hlo_profile_printer), + : hlo_profile_printer_data_(*hlo_profile_printer_data), hlo_profile_index_map_(*hlo_profile_index_map), profile_counters_( /*count*/ hlo_profile_index_map_.total_count(), diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 1a6b069609cb58bcc9659b4457453758a277bc0e..6fb91b9bef9d1df82b8806ce79cc147823edeb3d 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -77,8 +77,8 @@ class HloProfileIndexMap { std::unordered_map computation_to_profile_idx_; }; -// Create an instance of `HloProfilePrinter` that owns its memory. -std::unique_ptr CreateHloProfilePrinter( +// Create an instance of `HloProfilePrinterData`. +std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, const HloCostAnalysis& cost_analysis); @@ -90,7 +90,7 @@ class HloExecutionProfile { public: using DeviceDescription = perftools::gputools::DeviceDescription; - HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer, + HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data, const HloProfileIndexMap* hlo_profile_index_map); // Record how many cycles this HLO took to execute. @@ -117,11 +117,10 @@ class HloExecutionProfile { // debugging; e.g. emits cycle counts, execution time at the nominal device // frequency, and the effective throughput given the provided cost_analysis // for the operations in a given computation. Returns an empty string if it - // wasn't possible to generate a printable version. cost_analysis should be a - // clean analysis that can be used to visit the computation. + // wasn't possible to generate a printable version. string ToString(const DeviceDescription& device_description) const { - return hlo_profile_printer_.ToString(profile_counters_.data(), - device_description.clock_rate_ghz()); + return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(), + device_description.clock_rate_ghz()); } std::vector* mutable_profile_counters() { return &profile_counters_; } @@ -130,7 +129,7 @@ class HloExecutionProfile { } private: - const HloProfilePrinter& hlo_profile_printer_; + const HloProfilePrinterData& hlo_profile_printer_data_; const HloProfileIndexMap& hlo_profile_index_map_; // Stores per-Hlo profile counters. This is the only thing that changes when diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index b1e6729e2bccad4bdbe075a635d8a9b1ede6fecb..a0cb28246d3be541e798e85552436f64a3521f22 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -73,8 +73,8 @@ TEST_F(HloExecutionProfileTest, Basic) { HloCostAnalysis cost_analysis(shape_size_function); HloProfileIndexMap profile_index_map(*hlo_module); - std::unique_ptr profile_printer = - CreateHloProfilePrinter(profile_index_map, cost_analysis); + std::unique_ptr profile_printer = + CreateHloProfilePrinterData(profile_index_map, cost_analysis); HloExecutionProfile execution_profile(profile_printer.get(), &profile_index_map); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 90121f7ffe11b379bea9e83a483c7e752c97998c..a889c35aeb297bd118c40ced2dd9539957dce67a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -404,6 +404,9 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::StringPiece outfeed_config) { std::unique_ptr instruction = WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); + CHECK(ShapeUtil::Compatible(operand->shape(), shape)) + << "Outfeed shape " << shape << " must be compatible with operand shape " + << operand->shape(); instruction->AppendOperand(operand); instruction->outfeed_config_ = outfeed_config.ToString(); instruction->outfeed_shape_ = shape; @@ -669,6 +672,58 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBroadcastSequence( + const Shape& output_shape, HloInstruction* operand, + const std::function)>& + adder) { + CHECK(ShapeUtil::IsScalar(operand->shape()) || + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = ShapeUtil::ChangeElementType( + output_shape, operand->shape().element_type()); + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand->shape())) { + auto broadcast = + HloInstruction::CreateBroadcast(broadcast_shape, operand, {}); + broadcast->set_metadata(operand->metadata()); + if (operand->has_sharding()) { + broadcast->set_sharding(operand->sharding()); + } + return broadcast; + } + // Do explicit broadcast for degenerate broadcast. + std::vector broadcast_dimensions; + std::vector reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { + if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand->shape().dimensions(i)); + } else { + CHECK_EQ(operand->shape().dimensions(i), 1) + << "An explicit broadcast sequence requires the broadcasted " + "dimensions to be trivial; operand: " + << operand->ToString() << "; output_shape: " << output_shape; + } + } + // Eliminate the size one dimensions. + HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(operand->shape().element_type(), + reshaped_dimensions), + operand)); + reshaped_operand->set_metadata(operand->metadata()); + if (operand->has_sharding()) { + reshaped_operand->set_sharding(operand->sharding()); + } + // Broadcast 'reshape' up to the larger size. + auto broadcast = HloInstruction::CreateBroadcast( + broadcast_shape, reshaped_operand, broadcast_dimensions); + broadcast->set_metadata(operand->metadata()); + if (operand->has_sharding()) { + broadcast->set_sharding(operand->sharding()); + } + return broadcast; +} + /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e700ec1d2903ac0bb77e36097c3e1e582206e4d5..5e89dc79bea81e650331e320f7836fdde90b2a53 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -409,6 +409,20 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions); + // Creates a sequence of instructions that performs an explicit broadcast of + // the operand to the target shape. + // + // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is + // returned as a unique_ptr for API consistency with other factory methods in + // this interface. + // + // TODO(b/72173833) Ideally HloComputations would always be present, and so + // the adder being passed by the caller would not be necessary. + static std::unique_ptr CreateBroadcastSequence( + const Shape& output_shape, HloInstruction* operand, + const std::function)>& + adder); + // Creates a pad instruction, where the operand is padded on the edges and // between the elements with the given padding value. static std::unique_ptr CreatePad( diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 3af3b29cedd06996dd4a175fdb1584c705ceea87..1038ab555567aa654342d59e02efaf844f2b95ba 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -712,8 +712,8 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { {1, 2}, {3, 4}, }))); - auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); - auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); + auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); auto outfeed10 = builder.AddInstruction( HloInstruction::CreateOutfeed(shape10, constant, "")); auto outfeed01 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 992f55788b4900949f4994ba5b7be015bcd0d3de..9206cdac05fbc1d6051617ab4b0f3016f19e3c90 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -83,6 +83,7 @@ HLO_MATCHER(Abs); HLO_MATCHER(Add); HLO_MATCHER(Bitcast); HLO_MATCHER(Broadcast); +HLO_MATCHER(BatchNormGrad); HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 58bb94221149c9a8b550add900dff52a53565985..99d8dd04e5279e0e8a977370beedc4448dc6dc4b 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -523,7 +523,15 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { std::unordered_map clone_map; for (auto& computation : computations_) { - auto cloned_computation = computation->Clone(suffix); + if (computation->IsFusionComputation()) { + // Cloning of a fused computation is handled by its fusion instruction. + continue; + } + + // When cloning a computation, pass in the new module, so that for any + // fusion instruction in this computation, the fused computation will be + // deep cloned to the new module. + auto cloned_computation = computation->Clone(suffix, module.get()); InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); if (entry_computation_ == computation.get()) { @@ -537,8 +545,15 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { for (auto* instruction : cloned_computation->instructions()) { // Rewrite instruction's called_computation to point to the cloned // computations. - instruction->ReplaceCalledComputations( - [&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); }); + instruction->ReplaceCalledComputations([&](HloComputation* hlo) { + if (hlo->IsFusionComputation()) { + // Cloning of a fused computation has already been handled when its + // fusion instruction is cloned. So this hlo computation is already + // the cloned one. + return hlo; + } + return FindOrDie(clone_map, hlo); + }); } } return module; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 0f5d3dccb74e6e3c88e51685392171f940c03596..cd51fa4e8549daba3e953eece50cb3538f627b89 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -105,6 +105,48 @@ TEST_F(HloModuleTest, CloneTest) { } } +TEST_F(HloModuleTest, CloneHasFusion) { + auto module = CreateNewModule(); + + // Create the fused computation. + HloComputation* fused_computation; + { + auto b = HloComputation::Builder("Fused"); + auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + b.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x)); + fused_computation = module->AddEmbeddedComputation(b.Build()); + } + + // Create the entry computation. + { + auto b = HloComputation::Builder("Entry"); + auto input = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + b.AddInstruction( + HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput, + /*operands=*/{input}, fused_computation)); + module->AddEntryComputation(b.Build()); + } + + auto post_order = module->MakeComputationPostOrder(); + auto cloned_module = module->Clone("copy"); + auto post_order_copied = cloned_module->MakeComputationPostOrder(); + + EXPECT_EQ(post_order.size(), post_order_copied.size()); + for (auto origin = post_order.begin(), copied = post_order_copied.begin(); + origin != post_order.end() && copied != post_order_copied.end(); + ++origin, ++copied) { + if ((*origin)->name() == "Fused") { + // Clone of the fused computation is handled when its fusion instruction + // is cloned, which always use suffix ".clone". + EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name()); + } else { + EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name()); + } + } +} + TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 6f6e679a21870e46da85963c3b2998465ac43420..68e3c9618c1fe9daacb0aee3ee98862c8b9e4bc4 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -249,7 +249,7 @@ bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector pieces; pieces.push_back(name); - for (auto* computation : module_->computations()) { + for (auto* computation : module_->MakeNonfusionComputations()) { pieces.push_back(tensorflow::strings::Printf("computation %s:", computation->name().c_str())); const auto all = computation->MakeInstructionPostOrder(); diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 33bafd05c15c47abaa313f92eb53a791de43d7d9..aba66114de649ce7667ae77174e9c4073b010b90 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -310,5 +311,56 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { *dataflow)); } +// Regression test for HloOrdering::ToString() crashing when fed a computation +// containing a fusion node. +TEST_F(HloOrderingTest, ToStringDoesNotCrash) { + const char* module_str = R"( +HloModule test_module + +body.v8 { + prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.4, constant.1) + get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3 + get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1 + get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2 + ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7) +} + +condition.v4 { + constant.2 = s32[] constant(2) + prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) + get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0 + ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8) +} + +fused_computation { + get-tuple-element.5.param_1 = f32[3]{0} parameter(1) + get-tuple-element.6.param_2 = f32[3]{0} parameter(2) + add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2) + get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0) + ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1) +} + +ENTRY while.v11 { + constant.5 = s32[] constant(0) + constant.6 = f32[3]{0} constant({1, 1, 1}) + constant.7 = f32[3]{0} constant({2, 2, 2}) + constant.8 = f32[3]{0} constant({3, 3, 3}) + tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8) + while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8 + get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3 + get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1 + get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2 + ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, calls=fused_computation +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + DependencyHloOrdering ordering(module.get()); + ordering.ToString(); // Shouldn't crash. +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index e944ad15139af0d2f98e8e68d3d48303f47ecf1c..dcc22793015147aaf3229875078b2989e4ef7559 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -18,20 +18,20 @@ limitations under the License. #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" namespace xla { -string HloProfilePrinter::ToString(const int64* counters, - double clock_rate_ghz) const { +string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, + const int64* counters, double clock_rate_ghz) { + using HloComputationInfo = HloProfilePrinterData::HloComputationInfo; + using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo; + string result; - for (int computation_idx = 0; computation_idx < computation_infos_size_; - computation_idx++) { - const HloComputationInfo& computation = computation_infos_[computation_idx]; - const HloInstructionInfo* instructions_begin = computation.instructions; - const HloInstructionInfo* instructions_end = - computation.instructions + computation.instructions_size; + for (const HloComputationInfo& computation_info : + hlo_profile_printer_data.computation_infos()) { + const auto& instruction_infos = computation_info.instruction_infos(); bool any_instruction_profiled = - std::any_of(instructions_begin, instructions_end, + std::any_of(instruction_infos.begin(), instruction_infos.end(), [&](const HloInstructionInfo& instruction_info) { - return counters[instruction_info.profile_index] != 0; + return counters[instruction_info.profile_index()] != 0; }); if (!any_instruction_profiled) { @@ -41,16 +41,19 @@ string HloProfilePrinter::ToString(const int64* counters, // Once we start using this in AOT for real, we will probably need a more // minimal version of HumanReadableProfileBuilder. HumanReadableProfileBuilder builder( - computation.name, counters[computation.profile_index], clock_rate_ghz); + computation_info.name(), counters[computation_info.profile_index()], + clock_rate_ghz); - for (const auto* instruction = instructions_begin; - instruction != instructions_end; instruction++) { + for (const auto& instruction_info : instruction_infos) { builder.AddOp( - /*op_name=*/instruction->long_name, - /*short_name=*/instruction->short_name, instruction->category, - counters[instruction->profile_index], instruction->flop_count, - instruction->transcendental_count, instruction->bytes_accessed, - instruction->optimal_seconds); + /*op_name=*/instruction_info.long_name(), + /*short_name=*/instruction_info.short_name(), + instruction_info.category(), + counters[instruction_info.profile_index()], + instruction_info.flop_count(), + instruction_info.transcendental_count(), + instruction_info.bytes_accessed(), + instruction_info.optimal_seconds()); } result += builder.ToString(); @@ -58,10 +61,4 @@ string HloProfilePrinter::ToString(const int64* counters, return result; } - -HloProfilePrinter::~HloProfilePrinter() { - if (deleter_) { - deleter_(computation_infos_, computation_infos_size_); - } -} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h index 2f056490ae027872570f7a0821ee63114f49fab8..b72325c7554acad258c2da55a18e5e18ec1b06a6 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.h +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -13,91 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ #include #include #include +#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h" #include "tensorflow/compiler/xla/types.h" namespace xla { -// Instances of this class can pretty-print profile counters gathered from -// running an XLA computation without having access to the backing module. -class HloProfilePrinter { - public: - // Holds meta information about an HloInstruction. - // - // The pointer-typed fields can be owning or non-owning -- this decision is - // manifested as the deleter_ function in the containing HloProfilePrinter. - struct HloInstructionInfo { - // Textual information for pretty printing. - const char* long_name; - const char* short_name; - const char* category; - - // Metrics computed by HloCostAnalysis. - float flop_count; - float transcendental_count; - float bytes_accessed; - float optimal_seconds; - - // The index into the profile counters array for the HloInstruction - // corresponding to this HloInstructionInfo. - int64 profile_index; - }; - - // Holds meta information about an HloComputation. - // - // The pointer-typed fields can be owning or non-owning -- this decision is - // manifested as the deleter_ function in the containing HloProfilePrinter. - struct HloComputationInfo { - const char* name; - - // The index into the profile counters array for the HloInstruction - // corresponding to this HloComputationInfo. - int64 profile_index; - - HloInstructionInfo* instructions; - int64 instructions_size; - }; - - HloProfilePrinter( - HloComputationInfo* computation_infos, int64 computation_infos_size, - int64 profile_counters_size, - std::function deleter = nullptr) - : computation_infos_(computation_infos), - computation_infos_size_(computation_infos_size), - profile_counters_size_(profile_counters_size), - deleter_(std::move(deleter)) {} - - HloProfilePrinter(HloProfilePrinter&& other) { - std::swap(other.computation_infos_, computation_infos_); - std::swap(other.computation_infos_size_, computation_infos_size_); - std::swap(other.deleter_, deleter_); - } - - HloProfilePrinter(const HloProfilePrinter&) = delete; - HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; - - // Converts the profile counter sequence `counters` to a human readable string - // representation. - string ToString(const int64* counters, double clock_rate_ghz) const; - - // Returns the size of the profile buffer expected by this printer. - int64 profile_counters_size() const { return profile_counters_size_; } - - ~HloProfilePrinter(); - - private: - // The `computation_infos_` field can be owning or non-owning -- this decision - // is manifested as the deleter_ function. - HloComputationInfo* computation_infos_ = nullptr; - int64 computation_infos_size_ = 0; - int64 profile_counters_size_ = 0; - std::function deleter_; -}; +// Pretty-print an array of profile counters using hlo_profile_printer_data. +string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, + const int64* counters, double clock_rate_ghz); } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto new file mode 100644 index 0000000000000000000000000000000000000000..9f22b733fe1d676b177039a9d7a3064b8638d7bc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +option cc_enable_arenas = true; + +// Describes how to pretty-print a profile counter array gathered for a specific +// HloModule. +message HloProfilePrinterData { + // Pretty-printer information about an HloInstruction. + message HloInstructionInfo { + string long_name = 1; + string short_name = 2; + string category = 3; + + // Metrics computed by HloCostAnalysis. + float flop_count = 4; + float transcendental_count = 5; + float bytes_accessed = 6; + float optimal_seconds = 7; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloInstructionInfo. + int64 profile_index = 8; + } + + // Pretty-printer information about an HloComputation. + message HloComputationInfo { + string name = 1; + + // The index into the profile counters array for the HloComputation + // corresponding to this HloComputationInfo. + int64 profile_index = 2; + + // HloInstructionInfos for every HloInstruction in the HloComputation for + // corresponding to this HloComputattionInfo. + repeated HloInstructionInfo instruction_infos = 3; + } + + // HloComputationInfos for every HloComputation in the HloModule. + repeated HloComputationInfo computation_infos = 1; + + // The size of the profile counters array we will pretty-print. + int64 profile_counters_size = 2; +} diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h index 9aa3e501d5f85e3b61b20555e3d13c5687f33f2f..c4876b852e32d34693202f4023aa20ad2b301ffd 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -56,4 +56,4 @@ class HloTfGraphBuilder { } // namespace hlo_graph_dumper } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 9d9cf0c0f67f50a13f6d966079b3f9748b0a52e9..6e46f945e0a2d776ab557c10fedf9b5eb393f3c2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -107,8 +107,20 @@ Status ShapeVerifier::HandleInfeed(HloInstruction*) { return tensorflow::Status::OK(); } -Status ShapeVerifier::HandleOutfeed(HloInstruction*) { - return tensorflow::Status::OK(); +Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { + // Outfeed has a separate shape field for the value which is outfed to the + // host. The shape of the instruction itself is always nil because the outfeed + // produces no HLO value in the graph. + if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), + outfeed->operand(0)->shape())) { + return InvalidArgument( + "Expected outfeed to have shape compatible with operand's shape %s, " + "actual shape is %s:\n%s", + ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), + ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), + outfeed->ToString().c_str()); + } + return CheckShape(outfeed, ShapeUtil::MakeNil()); } Status ShapeVerifier::HandleRng(HloInstruction*) { @@ -159,7 +171,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)); + operand_shape.dimensions(operand_dimension)) + << broadcast->ToString() << " operand shape " << operand_shape; } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 6368611f323ad7c1ebade4941260e12ed2c6e45f..5a1d864e03d436bb29f7c98b9a373a19abc28a7e 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -127,4 +127,4 @@ class HloVerifier : public HloPassInterface { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index f80dace8775c5ed31addb4a3d134f53005c6df71..bbea6bee5659c73cc71f45ed5e6bbd51df26c050 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -81,7 +81,10 @@ OperandLayoutConstraint::OperandLayoutConstraint( operand_no_(operand_no) { CHECK(shape_layout_.LayoutIsSet()); CHECK(ShapeUtil::Compatible(shape_layout.shape(), - instruction->operand(operand_no)->shape())); + instruction->operand(operand_no)->shape())) + << shape_layout.shape() << " is not compatible with " + << instruction->operand(operand_no)->shape() << " (for operand " + << operand_no << " of instruction " << instruction->ToString() << ")"; } string OperandLayoutConstraint::ToString() const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 827e092a3fa9116c461716b27c309033f7988745..1c00b2aabd182da72e78d2c9c01cbe70cfd8e33c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ #include @@ -179,4 +179,4 @@ class KernelSupportLibrary { }; } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index d2bcb38d09218c72183c7cece95bef6371006555..8d1e6338e189a055ac20f09961a783b52600866d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -150,6 +150,8 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // addition to an addition on this type (int16) - this is just the type // used for storage. return llvm::Type::getInt16Ty(module->getContext()); + case F16: + return llvm::Type::getHalfTy(module->getContext()); case S32: case U32: return llvm::Type::getInt32Ty(module->getContext()); @@ -292,6 +294,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, ir_element_type, tensorflow::bit_cast(literal.Get(*multi_index))); break; + case F16: + value = llvm::ConstantFP::get( + ir_element_type, + static_cast(literal.Get(*multi_index))); + break; case F64: value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index f72f482e3128c61e53cc454e7da8b5795ba6f695..175b081e84d31779b15560cb0998011fe046ca01 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -90,4 +90,4 @@ Status EmitParallelFusedDynamicUpdateSliceInPlace( } // namespace llvm_ir } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index 598d08b7203b25b194dfc3b3125ec58c96b2cd4c..f4c63dd86b4d8a6f598d46047012e4e5bc7b3d7e 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -90,4 +90,4 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index fc848bdb036125e5dadb471be431d3d2523c6770..926ebbe3140d631a3fb03f41c687ae72c58706f5 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -419,6 +420,8 @@ StatusOr> Service::BuildExecutable( /*include_unreachable_instructions=*/ true)); + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor)); @@ -566,7 +569,7 @@ Service::ExecuteParallelAndRegisterResult( se::Stream* stream = index_to_profiled_stream.second; Executable* executable = executables[device]; const HloModule& module = executable->module(); - HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(), + HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); TF_RETURN_IF_ERROR( executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); @@ -1597,4 +1600,15 @@ StatusOr> Service::Replicas( return replicas; } +Status Service::MaybeDumpHloModule(const HloModule& module) const { + const string xla_dump_prepass_hlo_proto_to = + module.config().debug_options().xla_dump_prepass_hlo_proto_to(); + if (xla_dump_prepass_hlo_proto_to.empty()) { + return Status::OK(); + } + HloProto proto = MakeHloProto(module); + return protobuf_util::DumpProtoToDirectory( + proto, xla_dump_prepass_hlo_proto_to, module.name()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index f962d0cdc7d41e1aeab55da5abcb1b40215b4144..0a7d0b3a7d25a1b046852c87d8463d0169080a5e 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -340,6 +340,8 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; + Status MaybeDumpHloModule(const HloModule& module) const; + // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. DeviceHandle SingleComputationDeviceHandle() const; diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 7882b70ab7765ad528b68f97c115e3ae5f19e48a..2ea6507900e712200ce43e9b63577a4967381fdf 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2767,48 +2767,11 @@ HloComputation* ComputationLowerer::ResolveComputation( HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( HloInstruction* operand, const Shape& output_shape) { - CHECK(ShapeUtil::IsScalar(operand->shape()) || - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); - Shape broadcast_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), AsInt64Slice(output_shape.dimensions())); - // Do explicit broadcast for scalar. - if (ShapeUtil::IsScalar(operand->shape())) { - HloInstruction* broadcast = hlo_builder_.AddInstruction( - HloInstruction::CreateBroadcast(broadcast_shape, operand, {})); - broadcast->set_metadata(operand->metadata()); - if (operand->has_sharding()) { - broadcast->set_sharding(operand->sharding()); - } - return broadcast; - } - // Do explicit broadcast for degenerate broadcast. - std::vector broadcast_dimensions; - std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { - if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { - broadcast_dimensions.push_back(i); - reshaped_dimensions.push_back(operand->shape().dimensions(i)); - } - } - // Eliminate the size one dimensions. - HloInstruction* reshaped_operand = - hlo_builder_.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(operand->shape().element_type(), - reshaped_dimensions), - operand)); - reshaped_operand->set_metadata(operand->metadata()); - if (operand->has_sharding()) { - reshaped_operand->set_sharding(operand->sharding()); - } - // Broadcast 'reshape' up to the larger size. - HloInstruction* broadcast = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, reshaped_operand, broadcast_dimensions)); - broadcast->set_metadata(operand->metadata()); - if (operand->has_sharding()) { - broadcast->set_sharding(operand->sharding()); - } - return broadcast; + auto fadd = [this](std::unique_ptr x) { + return hlo_builder_.AddInstruction(std::move(x)); + }; + return fadd( + HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); } void ComputationLowerer::Visit( diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 50dac32a4ab0a5de756c1ddf5e62c3560e54a079..d3d55634c97bbdf3f81321d8089bb808c411340b 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -41,4 +41,4 @@ class WhileLoopSimplifier : public HloPassInterface { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 63afab4206eb072e84745ced3307295c0516da7b..063e312df66ce9cba0fa9f49c2fc6026ba6b74aa 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -29,4 +29,4 @@ class ZeroSizedHloElimination : public HloPassInterface { } }; } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index 903fee525520205dbd516897fe451b0fd59d3872..f2ce22d6721ff8da46f741ccedc2a63dea5994c8 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -15,8 +15,8 @@ limitations under the License. // Utility class for managing sparse array indices. -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ +#define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ #include @@ -173,4 +173,4 @@ void SparseIndexArray::SortWithValues( } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/compiler/xla/statusor_internals.h index a2fda5bb3c6f11c20fc45c57885b1ce7523db81d..14636bd144bc0a155fc96c5a350c658fd2dadfe6 100644 --- a/tensorflow/compiler/xla/statusor_internals.h +++ b/tensorflow/compiler/xla/statusor_internals.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#define TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" @@ -242,4 +242,4 @@ struct TraitsBase { } // namespace internal_statusor } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 3922c779a0979c493df84431bf97c1da57717443..3afd52b6b2573aaecb125ad6e5bd05b41a1fbc68 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -815,9 +815,6 @@ xla_test( xla_test( name = "bfloat16_test", srcs = ["bfloat16_test.cc"], - blacklisted_backends = [ - "gpu", - ], shard_count = 40, deps = [ ":test_utils", @@ -847,6 +844,30 @@ xla_test( ], ) +xla_test( + name = "half_test", + srcs = ["half_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":test_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "slice_test", srcs = ["slice_test.cc"], @@ -1013,6 +1034,7 @@ xla_test( name = "select_and_scatter_test", timeout = "long", srcs = ["select_and_scatter_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1051,6 +1073,19 @@ xla_test( ], ) +xla_test( + name = "reduce_hlo_test", + srcs = ["reduce_hlo_test.cc"], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "call_test", srcs = ["call_test.cc"], diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index e47fcad475bb176a7b4598daf2c98897eb34182b..b853dfaa15d7ff2e21048a5a6a486d22c5a05416 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -99,8 +99,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { auto expected = Literal::MakeTuple( {Literal::CreateR4( - {{{{static_cast(-1.7f)}, {static_cast(-2.04f)}}, - {{static_cast(0.105f)}, {static_cast(0.65f)}}}, + {{{{static_cast(-1.6875f)}, + {static_cast(-2.04f)}}, + {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) .get(), diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 7c9494f133f3db3733fc2ffa4dacfb9a71dd01d8..a677986cd926cc0054d8f36abc98ccac33dc043d 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -387,7 +387,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqualTuple(expected, *actual); + LiteralTestUtil::ExpectEqual(expected, *actual); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -399,7 +399,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNearTuple(expected, *actual, error); + LiteralTestUtil::ExpectNear(expected, *actual, error); } void ClientLibraryTestBase::ComputeAndCompare( diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index a559a653df89f3b99bd87665a7f2ccf99afa54e0..ba0319990bc04196386e6812b0a03671676698ec 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -431,6 +431,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -456,6 +457,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -481,6 +483,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -506,6 +509,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -531,6 +535,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index a10e17dbf34b3a6fe503f156fab496708b833c07..0ceb9aff378ae8aa8098be9360310b1d78d31ab2 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -608,5 +608,28 @@ INSTANTIATE_TEST_CASE_P( ); +TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillWithYX(Array2D({ + {bfloat16(1), bfloat16(2)}, + })); + Array4D filter_data(1, 1, 1, 2); + filter_data.FillWithYX(Array2D({ + {bfloat16(5), bfloat16(6)}, + })); + + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index ae3f887240d0ccffcc9c51a2c409de457a94f967..877dc7db0eec229a7119b3627f177a33ed0d971b 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -595,6 +595,11 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { // Single element, no wrap. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { + // Single element, no wrap. + std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } @@ -602,6 +607,11 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { // Multiple element, no wrap. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { + // Multiple element, no wrap. + std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } @@ -609,6 +619,11 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { // Multiple element, wrapping. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { + // Multiple element, wrapping. + std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } @@ -616,12 +631,21 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) { // Multiple element, update size larger than operand. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) { + // Multiple element, update size larger than operand. + std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) { std::vector operand_shape({3, 123, 247}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) { + std::vector operand_shape({3, 123, 247}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } @@ -629,6 +653,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) { XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) { std::vector operand_shape({32, 128, 1024}); RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); +} + +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLargerBF16)) { + std::vector operand_shape({32, 128, 1024}); RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); } diff --git a/tensorflow/compiler/xla/tests/filecheck.h b/tensorflow/compiler/xla/tests/filecheck.h index 493ff7414bde31b18a39a5098925d9c991529b00..3830d5a44d2ca483fbe839231b0136d13033b48b 100644 --- a/tensorflow/compiler/xla/tests/filecheck.h +++ b/tensorflow/compiler/xla/tests/filecheck.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ #include @@ -30,4 +30,4 @@ StatusOr RunFileCheck(const string& input, const string& pattern); } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ +#endif // TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec2f49d43bd8cee84c6b0abe1892e8b2278eefeb --- /dev/null +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -0,0 +1,257 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +// Tests the handling of the basic mathematics operations with F16 operands. + +namespace xla { +namespace { + +class HalfTestBase : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{0.001, 0.001}; + // Number of elements in the input buffers. + static const int kNumElements = 4; +}; + +using UnaryBuildFuncTy = + std::function; + +struct UnaryOpTestParam { + std::function compute_func; + UnaryBuildFuncTy build_func; +}; + +class UnaryOpTest : public HalfTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(UnaryOpTest, Ops) { + std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + std::function compute_func = GetParam().compute_func; + std::vector expected; + for (int64 i = 0; i < x.size(); ++i) { + expected.push_back(compute_func(x[i])); + } + + UnaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd); + + ComputeAndCompareR1(&builder, expected, {x_data.get()}, error_spec_); +} + +half sign_imp(half value) { + const float x(std::move(value)); + return half((x < .0) ? -1 : (x > .0)); +} + +half round_imp(half value) { + return half(round(static_cast(std::move(value)))); +} + +INSTANTIATE_TEST_CASE_P( + half, UnaryOpTest, + ::testing::Values(UnaryOpTestParam{[](half x) { return abs(x); }, + &ComputationBuilder::Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, + &ComputationBuilder::Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, + &ComputationBuilder::Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, + &ComputationBuilder::Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, + &ComputationBuilder::Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, + &ComputationBuilder::Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, + &ComputationBuilder::Log}, + UnaryOpTestParam{[](half x) { return -x; }, + &ComputationBuilder::Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, + &ComputationBuilder::Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, + &ComputationBuilder::Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, + &ComputationBuilder::Tanh} + + )); + +struct UnaryPredTestParam { + std::function compute_func; + UnaryBuildFuncTy build_func; +}; + +class UnaryPredTest : public HalfTestBase, + public ::testing::WithParamInterface { +}; + +XLA_TEST_P(UnaryPredTest, Ops) { + std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + std::function compute_func = GetParam().compute_func; + CHECK_EQ(kNumElements, x.size()); + bool expected[kNumElements]; + for (int64 i = 0; i < x.size(); ++i) { + expected[i] = compute_func(x[i]); + } + + UnaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd); + + ComputeAndCompareR1(&builder, expected, {x_data.get()}); +} + +INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, + ::testing::Values(UnaryPredTestParam{ + [](half x) { return isfinite(x); }, + &ComputationBuilder::IsFinite})); + +using BinaryBuildFuncTy = std::function)>; + +struct BinaryOpTestParam { + std::function compute_func; + BinaryBuildFuncTy build_func; +}; + +class BinaryOpTest : public HalfTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(BinaryOpTest, Ops) { + std::vector x({half(1.0), half(2.0), half(3.0), half(-4.0)}); + std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + ComputationDataHandle y_opnd; + auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", + &builder, &y_opnd); + + std::function compute_func = GetParam().compute_func; + std::vector expected; + for (int64 i = 0; i < x.size(); ++i) { + expected.push_back(compute_func(x[i], y[i])); + } + + BinaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd, y_opnd, {}); + + ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}, + error_spec_); +} + +half atan2_imp(half x, half y) { + return half(atan2(static_cast(std::move(x)), + static_cast(std::move(y)))); +} + +INSTANTIATE_TEST_CASE_P( + half, BinaryOpTest, + ::testing::Values( + BinaryOpTestParam{[](half x, half y) { return x + y; }, + &ComputationBuilder::Add}, + BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, + &ComputationBuilder::Atan2}, + BinaryOpTestParam{[](half x, half y) { return x / y; }, + &ComputationBuilder::Div}, + BinaryOpTestParam{[](half x, half y) { return max(x, y); }, + &ComputationBuilder::Max}, + BinaryOpTestParam{[](half x, half y) { return min(x, y); }, + &ComputationBuilder::Min}, + BinaryOpTestParam{[](half x, half y) { return x * y; }, + &ComputationBuilder::Mul}, + BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, + &ComputationBuilder::Pow}, + BinaryOpTestParam{[](half x, half y) { return x - y; }, + &ComputationBuilder::Sub} + + )); + +struct BinaryPredTestParam { + std::function compute_func; + BinaryBuildFuncTy build_func; +}; + +class BinaryPredTest + : public HalfTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(BinaryPredTest, Ops) { + std::vector x({half(1.0), half(2.0), half(0.2), half(-4.0)}); + std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + ComputationDataHandle y_opnd; + auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", + &builder, &y_opnd); + + std::function compute_func = GetParam().compute_func; + CHECK_EQ(kNumElements, x.size()); + bool expected[kNumElements]; + for (int64 i = 0; i < x.size(); ++i) { + expected[i] = compute_func(x[i], y[i]); + } + + BinaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd, y_opnd, {}); + + ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}); +} + +INSTANTIATE_TEST_CASE_P( + half, BinaryPredTest, + ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, + &ComputationBuilder::Eq}, + BinaryPredTestParam{[](half x, half y) { return x != y; }, + &ComputationBuilder::Ne}, + BinaryPredTestParam{[](half x, half y) { return x >= y; }, + &ComputationBuilder::Ge}, + BinaryPredTestParam{[](half x, half y) { return x > y; }, + &ComputationBuilder::Gt}, + BinaryPredTestParam{[](half x, half y) { return x <= y; }, + &ComputationBuilder::Le}, + BinaryPredTestParam{[](half x, half y) { return x < y; }, + &ComputationBuilder::Lt} + + )); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index e5b96c51ce303819e33d67f5f383c119d313bae1..f8205de702fb3534dcd7dbdce6ee0cbfb11d6ee4 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -301,6 +301,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case BF16: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case F16: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case F32: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; @@ -313,6 +316,10 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case TUPLE: { bool tuple_match = true; for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", + ShapeUtil::HumanString(expected.shape()))); + // Create LiteralViews of the expected and actual elements. auto result = Equal(LiteralView::Create(expected, {i}), LiteralView::Create(actual, {i})); @@ -336,47 +343,6 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, return result; } -/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple( - const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << expected.ToString(); - VLOG(1) << "actual: " << actual.ToString(); - - if (!ShapeUtil::IsTuple(expected.shape()) || - !ShapeUtil::IsTuple(actual.shape())) { - return ::testing::AssertionFailure() - << "tuples expected shape = " << expected.shape().ShortDebugString() - << " actual shape = " << actual.shape().ShortDebugString(); - } - AssertEqualShapes(expected.shape(), actual.shape()); - - ::testing::AssertionResult err = ::testing::AssertionSuccess(); - for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(tensorflow::strings::StrCat( - "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); - - ::testing::AssertionResult res = [&] { - if (ShapeUtil::IsTuple(expected_element.shape())) { - return EqualTuple(expected_element, actual_element); - } else { - return Equal(expected_element, actual_element); - } - }(); - - if (!res && err) { - err = res; - } - } - - return err; -} - -/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, - const Literal& actual) { - EXPECT_TRUE(EqualTuple(expected, actual)); -} - namespace { // Helper class for comparing floating-point literals within an error bound. @@ -417,6 +383,9 @@ class NearComparator { case BF16: ExpectLiteralsNear(expected, actual, 0); break; + case F16: + ExpectLiteralsNear(expected, actual, 0); + break; case F32: ExpectLiteralsNear(expected, actual, 0); break; @@ -609,14 +578,47 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, static_cast(actual)); } +template <> +bool NearComparator::ExpectValuesNear(half expected, half actual) { + return ExpectValuesNear(static_cast(std::move(expected)), + static_cast(std::move(actual))); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( const Literal& expected, const Literal& actual, const ErrorSpec& error) { - NearComparator comparator(error); - return comparator.ExpectNear(expected, actual) - ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << "values were not near"; + ::testing::AssertionResult err = + EqualShapes(expected.shape(), actual.shape()); + if (!err) { + return err; + } + + if (ShapeUtil::IsTuple(expected.shape())) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); + const auto expected_element = LiteralView::Create(expected, {i}); + const auto actual_element = LiteralView::Create(actual, {i}); + + ::testing::AssertionResult res = + Near(expected_element, actual_element, error); + if (err && !res) { + err = res; + } + } + return err; + } + + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + NearComparator comparator(error); + return comparator.ExpectNear(expected, actual) + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << "values were not near"; + } + + return Equal(expected, actual); } /* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, @@ -629,65 +631,13 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, : tensorflow::strings::StrCat("\nmessage: ", message)); } -/* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple( - const Literal& expected, const Literal& actual, const ErrorSpec& error) { - VLOG(1) << "expected: " << expected.ToString(); - VLOG(1) << "actual: " << actual.ToString(); - - if (!ShapeUtil::IsTuple(expected.shape()) || - !ShapeUtil::IsTuple(actual.shape())) { - return ::testing::AssertionFailure() - << "tuples expected shape = " << expected.shape().ShortDebugString() - << " actual shape = " << actual.shape().ShortDebugString(); - } - AssertEqualShapes(expected.shape(), actual.shape()); - - ::testing::AssertionResult err = ::testing::AssertionSuccess(); - for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(tensorflow::strings::StrCat( - "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); - - ::testing::AssertionResult res = [&] { - if (ShapeUtil::IsTuple(expected_element.shape())) { - return NearTuple(expected_element, actual_element, error); - } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) { - return Near(expected_element, actual_element, error); - } else { - return Equal(expected_element, actual_element); - } - }(); - - if (err && !res) { - err = res; - } - } - return err; -} - -/* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected, - const Literal& actual, - const ErrorSpec& error) { - EXPECT_TRUE(NearTuple(expected, actual, error)); -} - /*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( const Literal& expected, const Literal& actual, const tensorflow::gtl::optional& error) { - bool is_tuple = ShapeUtil::IsTuple(expected.shape()); if (error.has_value()) { - if (is_tuple) { - VLOG(1) << "Expects near tuple"; - return NearTuple(expected, actual, *error); - } VLOG(1) << "Expects near"; return Near(expected, actual, *error); } - if (is_tuple) { - VLOG(1) << "Expects equal tuple"; - return EqualTuple(expected, actual); - } VLOG(1) << "Expects equal"; return Equal(expected, actual); } @@ -712,6 +662,7 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, new_num_elements *= new_dimensions[i]; } CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); auto new_literal = MakeUnique( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); @@ -761,6 +712,10 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, new_literal->Set(to_multi_index, literal.Get(from_multi_index)); break; + case C64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; default: LOG(FATAL) << "Unhandled primitive element type: " << PrimitiveType_Name(literal.shape().element_type()); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index f53553c70170bdcda717e72ffd791016effd0774..9b0724262d51ec7964a918bb8eb8716308662b96 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -111,17 +111,18 @@ class LiteralTestUtil { static void ExpectR4EqualArray4D(const Array4D& expected, const Literal& actual); - // Returns whether the two tuples are equal. - static ::testing::AssertionResult EqualTuple( - const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; - - // Expects that the values of the elements in the expected and actual tuples - // are equal. Tuples are matched recursively. - static void ExpectEqualTuple(const Literal& expected, const Literal& actual); - // Asserts that the expected and actual literals are within the given error // bound for all elements. Also, asserts that the rank, dimensions sizes, and - // bounds are equivalent. Only supported for floating point values. + // bounds are equivalent. + // + // Tuples are matched recursively. When comparing tensors of + // non-floating-point type, checks for exact equality, ignoring the ErroSpec. + // + // If the shape of the literals is neither a complex/floating-point tensor nor + // a tuple which contains a complex/floating-point tensor, Near() is + // equivalent to Equal(). We don't raise an error in this case, because we + // want to allow callers to call Near() even if they have no preconceptions + // about the shapes being compared. static ::testing::AssertionResult Near( const Literal& expected, const Literal& actual, const ErrorSpec& error) TF_MUST_USE_RESULT; @@ -170,18 +171,6 @@ class LiteralTestUtil { const Literal& actual, const ErrorSpec& error); - // Returns whether the values of the elements in the expected and actual - // tuples are within the given error bound. Tuples are matched recursively. - // If the elements of the tuple are not floating-point types, the error spec - // is ignored and exact equality is checked. - static ::testing::AssertionResult NearTuple( - const Literal& expected, const Literal& actual, - const ErrorSpec& error) TF_MUST_USE_RESULT; - - // Expects that the expected and actual values are near. - static void ExpectNearTuple(const Literal& expected, const Literal& actual, - const ErrorSpec& error); - // If the error spec is given, returns whether the expected and the actual are // within the error bound; otherwise, returns whether they are equal. Tuples // will be compared recursively. diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 569d5944cab0ae8f6a7b58a651285d20d4f9d019..47cab796041e9669affaebd7866d0d80100730f1 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -44,8 +44,7 @@ TEST_F(LocalClientAotTest, Constant) { OpaqueData opaque_data{100, 20, 3}; void* parameters[] = {&opaque_data}; float out = 0; - char tmp[4] = {0}; - void* temporary_buffers[] = {nullptr, &out, &tmp}; + void* temporary_buffers[] = {nullptr, &out}; SumAndDouble(&out, &run_options, parameters, temporary_buffers); EXPECT_EQ(out, 246.0f); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 4d3b513b092e0b447a1452a3809fb7099e54dbb9..3704ddd8010bf727b75ff81b63605e8b7ffe2ca8 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -87,10 +87,9 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); - CHECK_EQ(result->buffer_sizes().size(), 3); + CHECK_EQ(result->buffer_sizes().size(), 2); CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer - CHECK_EQ(result->buffer_sizes()[2], sizeof(float)); // temp buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 0fb87c3c2ccbad387d46016cfad4e7d3cc537dcc..6c86dd5b9ef673c9facffafa37e00a859ce82010 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -221,5 +221,77 @@ INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool())); +class MatOpsDotAddTest_bf16 + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(MatOpsDotAddTest_bf16, Dot_Add_2x2_2x2) { + bool row_major = std::get<0>(GetParam()); + bool add_lhs = std::get<1>(GetParam()); + bool transpose = std::get<2>(GetParam()); + Array2D lhs( + {{bfloat16(1.0f), bfloat16(2.0f)}, {bfloat16(3.0), bfloat16(4.0)}}); + Array2D rhs( + {{bfloat16(10.0f), bfloat16(11.0f)}, {bfloat16(12.0f), bfloat16(13.0f)}}); + + auto minor_to_major = [](bool row_major) -> std::vector { + return {row_major ? 1 : 0, row_major ? 0 : 1}; + }; + + auto prim_type = primitive_util::NativeToPrimitiveType(); + Shape lhs_shape = + ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); + Shape rhs_shape = + ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); + + TF_ASSERT_OK_AND_ASSIGN( + auto lhs_handle, + client_->TransferToServer( + *Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + TF_ASSERT_OK_AND_ASSIGN( + auto rhs_handle, + client_->TransferToServer( + *Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + + ComputationBuilder builder(client_, TestName()); + auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_mat_arg = lhs_arg; + if (transpose) { + lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); + } + auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); + auto result = builder.Dot(lhs_mat_arg, rhs_arg); + Array2D expected; + if (add_lhs) { + result = builder.Add(result, lhs_arg); + if (transpose) { + expected = Array2D( + {{bfloat16(47), bfloat16(52)}, {bfloat16(71), bfloat16(78)}}); + } else { + expected = Array2D( + {{bfloat16(35), bfloat16(39)}, {bfloat16(81), bfloat16(89)}}); + } + } else { + result = builder.Add(result, rhs_arg); + if (transpose) { + expected = Array2D( + {{bfloat16(56), bfloat16(61)}, {bfloat16(80), bfloat16(87)}}); + } else { + expected = Array2D( + {{bfloat16(44), bfloat16(48)}, {bfloat16(90), bfloat16(98)}}); + } + } + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, + ErrorSpec(1e-6)); +} + +INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest_bf16, + ::testing::Combine(::testing::Bool(), ::testing::Bool(), + ::testing::Bool())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 6489eee9f34c6c4426d52e166f7b401d5948742f..6aafb9fa6cb2175c478f0e9a5e16f5808cbea590 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/xla/client/computation_builder.h" @@ -36,36 +37,42 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - void UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims); - - template - void BernoulliTest(float p, tensorflow::gtl::ArraySlice dims); + std::unique_ptr UniformTest(T a, T b, + tensorflow::gtl::ArraySlice dims, + int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution // of the given range size. `expected_count` is the number of times each // possible value is expected to be generated. Thus, the sample size is // `range_size * expected_count`. - double UniformChiSquared(int32 range_size, int32 expected_count); + double UniformChiSquared(int32 range_size, int32 expected_count, + int64 seed = 42); }; template -void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { +std::unique_ptr PrngTest::UniformTest( + T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { ComputationBuilder builder(client_, TestName()); builder.RngUniform( builder.ConstantR0(a), builder.ConstantR0(b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); - SetSeed(42); + SetSeed(seed); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); + return actual; } // Uniform random number generation tests XLA_TEST_F(PrngTest, ScalarU01) { UniformTest(0, 1, {}); } +XLA_TEST_F(PrngTest, ScalarU01limits) { + UniformTest(std::numeric_limits::min(), + std::numeric_limits::max(), {}); +} XLA_TEST_F(PrngTest, ZeroValuesU01) { UniformTest(0, 1, {0}); } XLA_TEST_F(PrngTest, TenValuesU01) { UniformTest(0, 1, {10}); } XLA_TEST_F(PrngTest, TenValuesU37) { UniformTest(3, 7, {10}); } @@ -73,6 +80,56 @@ XLA_TEST_F(PrngTest, ZeroValuesR2) { UniformTest(0, 1, {0, 20}); } XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } +// TODO(b/71543667): Fix Rng ops on LLVM backends. +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( + DISABLED_ON_CPU(ScalarBF16Tests)))) { + for (int64 seed = 0; seed < 100; ++seed) { + // The largest negative number smaller than zero in bf16 that's not + // denormalized. + int32 low_raw = 0x80800000; + const float low = reinterpret_cast(low_raw); + float high = 0.0f; + UniformTest(static_cast(low), + static_cast(high), {}, /*seed=*/seed); + + // Test odd and even values. + UniformTest(static_cast(32.75), + static_cast(33), {}, /*seed=*/seed); + UniformTest(static_cast(32.50), + static_cast(32.75), {}, /*seed=*/seed); + UniformTest(static_cast(-33.00), + static_cast(-32.75), {}, /*seed=*/seed); + UniformTest(static_cast(-32.75), + static_cast(-32.50), {}, /*seed=*/seed); + } +} + +// TODO(b/71543667): Fix Rng ops on LLVM backends. +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU( + DISABLED_ON_CPU_PARALLEL(ScalarBF16CountTests)))) { + // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, + // they should get similar counts. + bfloat16 low = static_cast(32.25); + bfloat16 high = static_cast(33); + bfloat16 interval = static_cast(0.25); + std::vector counts(static_cast((high - low) / interval), 0); + + constexpr int64 count = 100; + for (int64 seed = 0; seed < count; ++seed) { + auto result = UniformTest(low, high, {}, /*seed=*/seed); + result->Literal::EachCell( + [&](tensorflow::gtl::ArraySlice, bfloat16 value) { + int64 index = static_cast((value - low) / interval); + counts[index]++; + }); + } + // Each bucket should have similar amount of counts. That is, not more than + // 10% of total counts. This mostly tests that we don't fall into a 1:2:2 + // distribution, which yields 20% expected difference. + EXPECT_LT(std::abs(counts[0] - counts[1]), count * 0.1); + EXPECT_LT(std::abs(counts[1] - counts[2]), count * 0.1); +} + namespace { template T Square(T x) { @@ -80,7 +137,8 @@ T Square(T x) { } } // namespace -double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { +double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, + int64 seed) { int32 sample_size = range_size * expected_count; ComputationBuilder builder(client_, TestName()); @@ -88,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { builder.ConstantR0(range_size), ShapeUtil::MakeShape(S32, {sample_size})); - SetSeed(42); + SetSeed(seed); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); std::vector counts(range_size, 0); actual->EachCell([&counts](tensorflow::gtl::ArraySlice, diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0a2c0ca4cb8414e0771a541b9f963f9aedc8376 --- /dev/null +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +// Tests the Reduce HLO in ways that can't be done using the ComputationBuilder +// API. + +namespace xla { +namespace { + +namespace str_util = tensorflow::str_util; +namespace strings = tensorflow::strings; + +struct ReduceLayout { + std::array input_minor_to_major; + std::array output_minor_to_major; + + string ToString() const { + return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_", + str_util::Join(output_minor_to_major, "x")); + } +}; + +string PrintReduceLayout( + ::testing::TestParamInfo reduce_layout_param) { + return reduce_layout_param.param.ToString(); +} + +void PrintTo(const ReduceLayout& reduce_layout, ::std::ostream* os) { + *os << reduce_layout.ToString(); +} + +class ReduceWithLayoutTest + : public HloTestBase, + public ::testing::WithParamInterface {}; + +StatusOr> GetParsedModule() { + const char* const hlo_string = R"( +HloModule BadReduce + +Sum { + x.1 = f32[] parameter(0) + y.1 = f32[] parameter(1) + ROOT add.1 = f32[] add(x.1, y.1) +} + +ENTRY reduce.1 { + parameter = f32[2,2,2,3]{3,2,1,0} parameter(0) + init_value = f32[] constant(0) + reduce = f32[2,2,3]{2,1,0} reduce(parameter, init_value), dimensions={1}, to_apply=Sum + ROOT copy = f32[2,2,3]{2,1,0} copy(reduce) +} +)"; + + return tools::Parse(hlo_string); +} + +// TODO(b/72454718): XLA:GPU does not support executing code compiled without +// optimizations. +XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetParsedModule()); + HloInstruction* reduce_instruction = + module->entry_computation()->root_instruction()->mutable_operand(0); + ASSERT_EQ(reduce_instruction->opcode(), HloOpcode::kReduce); + + const ReduceLayout& reduce_layout = GetParam(); + + Shape* reduce_output_shape = reduce_instruction->mutable_shape(); + *reduce_output_shape->mutable_layout() = + LayoutUtil::MakeLayout(reduce_layout.output_minor_to_major); + + Shape* reduce_input_shape = + reduce_instruction->mutable_operand(0)->mutable_shape(); + *reduce_input_shape->mutable_layout() = + LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); + + std::unique_ptr reduce_input = + Literal::CreateR4({{ /*i0=0*/ + {/*i1=0*/ + {-0.246092796, -0.179497838, -0.161181688}, + {-0.151643038, -0.240213156, -0.198156}}, + {/*i1=1*/ + {-0.14222312, -0.162200093, -0.193907976}, + {-0.239411, -0.198166847, -0.172471642}}}, + { /*i0=1*/ + {/*i1=0*/ + {-0.22965157, -0.218723893, -0.129257083}, + {-0.188762426, -0.16123569, -0.181166649}}, + {/*i1=1*/ + {-0.241772294, -0.245131493, -0.160247207}, + {-0.179881215, -0.23383224, -0.121976733}}}}); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + +INSTANTIATE_TEST_CASE_P(ReduceWithLayoutTest_Instantiation, + ReduceWithLayoutTest, + ::testing::Values( // + ReduceLayout{{3, 2, 1, 0}, {0, 1, 2}}, // + ReduceLayout{{3, 2, 1, 0}, {0, 2, 1}}, // + ReduceLayout{{3, 2, 1, 0}, {1, 2, 0}}, // + ReduceLayout{{3, 2, 1, 0}, {1, 0, 2}}, // + ReduceLayout{{3, 2, 1, 0}, {2, 0, 1}}, // + ReduceLayout{{3, 2, 1, 0}, {2, 1, 0}}, // + ReduceLayout{{3, 1, 2, 0}, {1, 2, 0}}, // + ReduceLayout{{1, 2, 3, 0}, {1, 0, 2}}, // + ReduceLayout{{0, 2, 1, 3}, {2, 0, 1}}), // + PrintReduceLayout); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 01f23efcd52e3b227309df3b7d965f3b4c3a0cdf..73b37e201afa13546179e2ce7a76d3f7967de524 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -533,6 +533,7 @@ struct R4ReduceWindowTestData { int64 strides[4]; int64 pad_low[4]; int64 pad_high[4]; + int64 layout[4]; Reducer reducer; }; @@ -548,7 +549,8 @@ string R4ReduceWindowTestDataToString( "__strides_", tensorflow::str_util::Join(param.strides, "x"), // "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // - (param.reducer == kAdd) ? "add" : "max"); + "__layout_", tensorflow::str_util::Join(param.layout, "_"), // + (param.reducer == kAdd) ? "_add" : "_max"); CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. @@ -575,7 +577,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, param.base_bounds[2], param.base_bounds[3]); input.FillIota(1); std::unique_ptr input_literal = - Literal::CreateR4FromArray4D(input); + Literal::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); ComputationDataHandle parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -611,8 +614,13 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); + std::unique_ptr expected_literal = + Literal::CreateFromArray(*expected); + const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + input_literal->shape().element_type(), + AsInt64Slice(expected_literal->shape().dimensions()), param.layout); + ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()}, + DefaultErrorSpec(), &expected_shape_with_layout); } }; @@ -626,6 +634,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Arbitrary padding (not kSame or kValid). @@ -634,6 +643,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{2, 2, 1, 1}, /*pad_low=*/{4, 4, 0, 0}, /*pad_high=*/{4, 4, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Zero base bound edge case. @@ -642,6 +652,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With non-1x1 window. @@ -650,6 +661,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With max instead of add. @@ -658,6 +670,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kMax}, // With stride. @@ -666,6 +679,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{2, 4, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With low padding. @@ -674,6 +688,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{2, 2, 1, 1}, /*pad_low=*/{3, 2, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With high padding. @@ -682,6 +697,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{2, 2, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{2, 3, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Window touches both sides of the padding simultaneously. @@ -690,6 +706,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{1, 1, 0, 0}, /*pad_high=*/{1, 1, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Window is entirely in the padding for some positions. @@ -698,6 +715,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{4, 4, 0, 0}, /*pad_high=*/{4, 4, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Zero base bound with padding edge case. @@ -706,6 +724,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 1, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With stride, low padding and high padding. @@ -714,6 +733,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{3, 1, 1, 1}, /*pad_low=*/{10, 1, 0, 0}, /*pad_high=*/{2, 3, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With second minor dimension == 9. @@ -722,6 +742,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With minor dimension == 129. @@ -730,6 +751,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With minor dims reduction and non-overlapped stride. @@ -738,6 +760,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 2, 2}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With minor dims reduction and overlapped stride. @@ -745,7 +768,8 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*window_bounds=*/{1, 1, 4, 4}, /*strides=*/{1, 1, 2, 2}, /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, + /*pad_high=*/{1, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, }; @@ -762,10 +786,11 @@ XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); } // Test cases that are large/slow/failed. const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128}, - /*window_bounds=*/{3, 3, 1, 1}, - /*strides=*/{1, 1, 1, 1}, + /*window_bounds=*/{3, 3, 1, 5}, + /*strides=*/{1, 1, 1, 5}, /*pad_low=*/{1, 1, 0, 0}, /*pad_high=*/{1, 1, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kMax}, R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128}, @@ -773,6 +798,7 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*strides=*/{2, 2, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{1, 1, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, }; @@ -782,6 +808,54 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(use_bfloat16_params)), R4ReduceWindowTestDataToString); +class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {}; + +// TODO(b/72234705): Fix the test cases failed on CPU and GPU. +XLA_TEST_P(R4ReduceWindowAnyDimsTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) { + DoIt(); +} + +const R4ReduceWindowTestData kR4ReduceWindowAnyDimsTestValues[] = { + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kMax}, + // With 0321 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 2, 3, 4}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 3, 2, 1}, + /*reducer=*/kAdd}, + + // With 0123 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23}, + /*window_bounds=*/{2, 3, 7, 9}, + /*strides=*/{1, 2, 5, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 1, 2, 3}, + /*reducer=*/kAdd}, +}; + +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowAnyDimsTestInstantiation, R4ReduceWindowAnyDimsTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowAnyDimsTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); + struct R3ReduceWindowTestData { int64 base_bounds[3]; int64 window_bounds[3]; diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 62ff349e9c011e0eb845192013a74aeb0956b791..9ee94b8571e5fc8789b60501462986967ce909a0 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -39,8 +39,8 @@ namespace xla { namespace { struct SelectAndScatterTestParam { - Array4D operand_shape; - Array4D source_shape; + std::vector operand_shape; + std::vector source_shape; Padding padding_type; tensorflow::gtl::ArraySlice window_dimensions; tensorflow::gtl::ArraySlice window_strides; @@ -69,83 +69,132 @@ class SelectAndScatterTest Computation min_f32_; }; -XLA_TEST_P(SelectAndScatterTest, R4Randomized) { - Array4D o(GetParam().operand_shape); +XLA_TEST_P(SelectAndScatterTest, ParamTest) { + auto operand_shape = GetParam().operand_shape; + Array o(operand_shape); o.FillRandom(1.5f); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = builder_.ConstantFromArray(o); - Array4D s(GetParam().source_shape); + auto source_shape = GetParam().source_shape; + Array s(source_shape); s.FillRandom(12.0f); - auto source = builder_.ConstantR4FromArray4D(s); - - builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, - GetParam().window_strides, GetParam().padding_type, - source, builder_.ConstantR0(0.0f), add_f32_); + auto source = builder_.ConstantFromArray(s); - auto e = ReferenceUtil::SelectAndScatter4DGePlus( - o, s, 0.0f, GetParam().window_dimensions, GetParam().window_strides, - GetParam().padding_type == Padding::kSame); + auto select_and_scatter = builder_.SelectAndScatter( + operand, ge_f32_, GetParam().window_dimensions, GetParam().window_strides, + GetParam().padding_type, source, builder_.ConstantR0(0.0f), + add_f32_); - ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-5)); + ComputeAndCompare(&builder_, select_and_scatter, {}, ErrorSpec(1e-5)); } INSTANTIATE_TEST_CASE_P( SelectAndScatterTest_Instantiation, SelectAndScatterTest, - ::testing::Values(SelectAndScatterTestParam{{6, 6, 256, 128}, - {3, 3, 256, 128}, - Padding::kSame, - {3, 3, 1, 1}, - {2, 2, 1, 1}}, - SelectAndScatterTestParam{{7, 7, 256, 128}, - {3, 3, 256, 128}, - Padding::kValid, - {3, 3, 1, 1}, - {2, 2, 1, 1}}, - SelectAndScatterTestParam{{6, 7, 256, 128}, - {3, 3, 256, 128}, - Padding::kValid, - {2, 3, 1, 1}, - {2, 2, 1, 1}}, - SelectAndScatterTestParam{{6, 7, 256, 128}, - {2, 3, 256, 128}, - Padding::kValid, - {2, 3, 1, 1}, - {3, 2, 1, 1}}, - SelectAndScatterTestParam{{9, 9, 16, 128}, - {3, 3, 16, 128}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{3, 3, 4, 4}, - {1, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{3, 3, 4, 4}, - {1, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{9, 3, 4, 4}, - {3, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{7, 3, 4, 4}, - {3, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {2, 3, 1, 1}}, - SelectAndScatterTestParam{{1, 1, 5, 5}, - {1, 1, 5, 5}, - Padding::kSame, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{7, 7, 8, 256}, - {4, 4, 8, 256}, - Padding::kSame, - {2, 2, 1, 1}, - {2, 2, 1, 1}})); + ::testing::Values( + SelectAndScatterTestParam{{6, 6, 6, 4, 4}, + {3, 3, 3, 4, 4}, + Padding::kSame, + {3, 3, 3, 1, 1}, + {2, 2, 2, 1, 1}}, + SelectAndScatterTestParam{{7, 7, 7, 4, 4}, + {3, 3, 3, 4, 4}, + Padding::kValid, + {3, 3, 3, 1, 1}, + {2, 2, 2, 1, 1}}, + + SelectAndScatterTestParam{{8, 8, 8, 4, 4}, + {1, 3, 3, 4, 4}, + Padding::kValid, + {8, 4, 4, 1, 1}, + {1, 2, 2, 1, 1}}, + SelectAndScatterTestParam{{6, 6, 256, 128}, + {3, 3, 256, 128}, + Padding::kSame, + {3, 3, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{{7, 7, 256, 128}, + {3, 3, 256, 128}, + Padding::kValid, + {3, 3, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{{6, 7, 256, 128}, + {3, 3, 256, 128}, + Padding::kValid, + {2, 3, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{{6, 7, 256, 128}, + {2, 3, 256, 128}, + Padding::kValid, + {2, 3, 1, 1}, + {3, 2, 1, 1}}, + SelectAndScatterTestParam{{9, 9, 16, 128}, + {3, 3, 16, 128}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{3, 3, 4, 4}, + {1, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{3, 3, 4, 4}, + {1, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{9, 3, 4, 4}, + {3, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{7, 3, 4, 4}, + {3, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {2, 3, 1, 1}}, + SelectAndScatterTestParam{{1, 1, 5, 5}, + {1, 1, 5, 5}, + Padding::kSame, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{7, 7, 8, 256}, + {4, 4, 8, 256}, + Padding::kSame, + {2, 2, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{ + {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}}, + SelectAndScatterTestParam{ + {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}}, + SelectAndScatterTestParam{{7, 256, 128}, + {3, 256, 128}, + Padding::kValid, + {3, 1, 1}, + {2, 1, 1}}, + SelectAndScatterTestParam{{6, 256, 128}, + {3, 256, 128}, + Padding::kValid, + {2, 1, 1}, + {2, 1, 1}}, + SelectAndScatterTestParam{{6, 256, 128}, + {2, 256, 128}, + Padding::kValid, + {2, 1, 1}, + {3, 1, 1}}, + SelectAndScatterTestParam{ + {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}}, + SelectAndScatterTestParam{ + {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}})); // Test for F32 1D array, with a zero-element input. XLA_TEST_F(SelectAndScatterTest, R1S0F32) { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 8b10aef5b81c18648b6e255445d66a6d195f8a76..0e90a323583de7336556c203a4b46fc14b53454d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -34,7 +34,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal) { TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice indices) { // Generate a random uniforma number from -0.0625 and 0.0625 and bias it - // with a position dependent nubmer with mean 0.037109375. These number + // with a position dependent number with mean 0.037109375. These number // should allow for long chains of accumulation without being too close // to zero or to large to accumulate all numbers accurately. return (generator(engine) - 1.0625) + diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index fa4192e9281784a4a3063601afe89fba6a9dac18..835e2d7e5594d7c8c6e523f9806e32dce23a87e9 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -215,5 +215,23 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); } +XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 1}); + auto rhs = builder.ConstantR1({1, 1}); + builder.ConvertElementType(builder.Eq(lhs, rhs), S32); + + ComputeAndCompareR1(&builder, {0, 1}, {}); +} + +XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 1}); + auto rhs = builder.ConstantR1({1, 1}); + builder.ConvertElementType(builder.Eq(lhs, rhs), F32); + + ComputeAndCompareR1(&builder, {0.0, 1.0}, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 146fbadcb68e6c5d0fa0856c1c98b399df72051f..1d2f436194a921c8d1b23732e2b4be11b59ac043 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -110,7 +110,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( - &executable->hlo_profile_printer(), &executable->hlo_profile_index_map()); + &executable->hlo_profile_printer_data(), + &executable->hlo_profile_index_map()); TF_ASSERT_OK_AND_ASSIGN( Backend::StreamPtr stream_ptr, diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 1c68e271e0f75d8facc36bd0878190f3db512972..42e7f91f26f3454b247d95d328c3422c44131c43 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -931,7 +931,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - shape, operands[0], config ? *config : "")); + operands[0]->shape(), operands[0], config ? *config : "")); break; } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index bb2db2010c5e0da6ed3fde628eb5928d555815b2..1d7dd344493f91d84714c72783c95a49ad72ad1c 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -398,13 +398,11 @@ std::vector> CommonFactors( // Removes illegal characters from filenames. string SanitizeFileName(string file_name); -// Simple wrapper around std::all_of. template bool c_all_of(Container container, Predicate predicate) { return std::all_of(std::begin(container), std::end(container), predicate); } -// Simple wrapper around std::transform. template OutputIterator c_transform(InputContainer input_container, @@ -414,7 +412,6 @@ OutputIterator c_transform(InputContainer input_container, output_iterator, unary_op); } -// Simple wrapper around std::copy_if. template OutputIterator c_copy_if(InputContainer input_container, OutputIterator output_iterator, @@ -423,6 +420,11 @@ OutputIterator c_copy_if(InputContainer input_container, output_iterator, predicate); } +template +void c_sort(InputContainer& input_container, Comparator comparator) { + std::sort(input_container.begin(), input_container.end(), comparator); +} + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 224eb2a20c8fc5ac4bfe2bb92a65a3bd178dbaf6..55f42ed3a454baa3f8b6adf60a78582488733e9b 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -25,6 +26,26 @@ limitations under the License. namespace xla { namespace window_util { +Window MakeWindow(tensorflow::gtl::ArraySlice sizes) { + Window window; + for (int64 size : sizes) { + auto* dimension = window.add_dimensions(); + dimension->set_size(size); + dimension->set_stride(1); + } + return window; +} + +PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes) { + PaddingConfig config; + for (int64 size : sizes) { + auto* dimension = config.add_dimensions(); + dimension->set_edge_padding_low(size); + dimension->set_edge_padding_high(size); + } + return config; +} + /* static */ string ToString(const WindowDimension& dim) { using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; @@ -114,13 +135,21 @@ bool HasPadding(const Window& window) { return false; } -bool HasEvenPadding(const Window& window) { +bool HasSymmetricPadding(const Window& window) { return std::all_of(window.dimensions().begin(), window.dimensions().end(), [](const WindowDimension& dim) { return dim.padding_low() == dim.padding_high(); }); } +bool HasSymmetricPadding(const PaddingConfig& padding_config) { + return std::all_of(padding_config.dimensions().begin(), + padding_config.dimensions().end(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.edge_padding_low() == dim.edge_padding_high(); + }); +} + bool HasNegativePadding(const Window& window) { return std::any_of(window.dimensions().begin(), window.dimensions().end(), [](const WindowDimension& dim) { diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 17c388fc0b551ec227802434b7db435c4d25d985..ba473e2c8c35202865a9a4981da7653fe1d6f552 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -18,10 +18,21 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace window_util { +// Creates a window with the given sizes in the dimensions and all strides set +// to 1. +Window MakeWindow(tensorflow::gtl::ArraySlice sizes); + +// Creates a padding config with symmetrical padding in each dimension, of value +// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero +// pixels of padding in dimension 0, one pixel of padding symmetrically, on each +// side of dimension 1, and two pixels of padding symmetrically on dimension 2. +PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes); + string ToString(const WindowDimension& dim); string ToString(const Window& window); @@ -32,9 +43,14 @@ string ToString(const Window& window); bool HasStride(const Window& window); bool HasPadding(const Window& window); -bool HasEvenPadding(const Window& window); +bool HasSymmetricPadding(const Window& window); bool HasNegativePadding(const Window& window); +// As with HasSymmetricPadding(Window) above, returns whether the "padding low" +// is equivalent to the "padding high" for all dimensions, but works on a +// padding configuration. +bool HasSymmetricPadding(const PaddingConfig& padding_config); + bool HasBaseDilation(const Window& window); bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index fda1a4c27b6dea1b7e4dee76de976f93ba61c007..e1ed08c8480fa73e9c5ff914bb9f5e38f1ce96e9 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -179,6 +179,10 @@ message DebugOptions { // ops. bool xla_gpu_use_cudnn_batchnorm = 94; + // Dump compilation artifacts, before hlo passes are executed, in binary proto + // into this directory. + string xla_dump_prepass_hlo_proto_to = 95; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 8bed0fabd743c9cf9a51fe574401ae42730d15b4..f1e54432faa3c59ada0d89c472bcdcc28f6d0970 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -111,7 +111,6 @@ cc_library( name = "contrib_kernels", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/batching:batch_ops_kernels", "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", @@ -134,7 +133,6 @@ cc_library( name = "contrib_ops_op_lib", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/batching:batch_ops_op_lib", "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index f600a8a99816586d6bd7d7ab51354888c435e739..8f6a3cb1ca4544cae6f42fd1727d509af9fc0233 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # Add projects here, they will show up under tf.contrib. +from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h index 2b43939f148e360945e5d488d148fcb2c13008a6..665304b5eef1f8a3633c8c522259e20d744b1808 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.h +++ b/tensorflow/contrib/android/asset_manager_filesystem.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ +#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ +#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ #include #include @@ -79,4 +79,4 @@ class AssetManagerFileSystem : public FileSystem { }; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ +#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index cd98f0e70335db715b8cb6c76a9d7df3e2280552..ee67909133fc26ba98355db05a4b90d3dfa6b97b 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -67,48 +67,14 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -tf_custom_op_library( - name = "python/ops/_batch_ops.so", - srcs = ["ops/batch_ops.cc"], - deps = [ - "//tensorflow/contrib/batching/kernels:batch_kernels", - ], -) - -tf_gen_op_libs( - op_lib_names = ["batch_ops"], -) - -tf_gen_op_wrapper_py( - name = "batch_ops", - deps = [":batch_ops_op_lib"], -) - -tf_kernel_library( - name = "batch_ops_kernels", - deps = [ - "//tensorflow/contrib/batching/kernels:batch_kernels", - "//tensorflow/contrib/batching/util:periodic_function", - "//tensorflow/core/kernels:concat_lib", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/kernels:split_lib", - ], - alwayslink = 1, -) - -tf_custom_op_py_library( +py_library( name = "batch_py", srcs = glob(["python/ops/*.py"]) + ["__init__.py"], - dso = [":python/ops/_batch_ops.so"], - kernels = [ - ":batch_ops_kernels", - ":batch_ops_op_lib", - ], srcs_version = "PY2AND3", deps = [ - ":batch_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", + "//tensorflow/python:batch_ops_gen", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", @@ -118,6 +84,14 @@ tf_custom_op_py_library( ], ) +cc_library( + name = "batch_ops_kernels", + deps = [ + "//tensorflow/core/kernels:batch_kernels", + ], + alwayslink = 1, +) + py_test( name = "batch_ops_test", size = "small", @@ -133,6 +107,7 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:script_ops", ], diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index 60861f83f450d3f67f21a46bdfa3fda223b9d2b4..86250e6692004a12a1fa338767a5db1e4c2e4195 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 63ba8fcf45d8e6caad14c267bb19c0bc4eea20bf..d9b37da6933aa0847c229607f43d1d5d121a928c 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ #include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index 3afce2761f748136f4d556017823db8dbd4af50e..8e94e1fd8b969d4fef8dbc8c322557f9da3833e6 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/kernels/BUILD b/tensorflow/contrib/batching/kernels/BUILD deleted file mode 100644 index 6e53dd9a5fc0201c5ed91d1eaf07f940e341fb5e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/kernels/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -# Description: -# Contains kernels for the batching ops. - -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "batch_kernels", - srcs = ["batch_kernels.cc"], - deps = [ - "//tensorflow/contrib/batching:shared_batch_scheduler_hdrs", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:concat_lib_hdrs", - "//tensorflow/core/kernels:ops_util_hdrs", - "//tensorflow/core/kernels:split_lib_hdrs", - ], - alwayslink = 1, -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/batching/ops/batch_ops.cc b/tensorflow/contrib/batching/ops/batch_ops.cc deleted file mode 100644 index 85e0ccba4aa372bdc21fb194263569b8b787bb6c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/ops/batch_ops.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("Batch") - .Input("in_tensors: T") - .Output("batched_tensors: T") - .Output("batch_index: int64") - .Output("id: int64") - .Attr("num_batch_threads: int") - .Attr("max_batch_size: int") - .Attr("batch_timeout_micros: int") - .Attr("allowed_batch_sizes: list(int) = []") - .Attr("grad_timeout_micros: int") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("batching_queue: string = ''") - .Attr("T: list(type)") - .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector in_shapes; - TF_RETURN_IF_ERROR(c->input("in_tensors", &in_shapes)); - std::vector out_shapes(in_shapes.size()); - for (int i = 0; i < in_shapes.size(); ++i) { - TF_RETURN_IF_ERROR( - c->ReplaceDim(in_shapes[i], 0, c->UnknownDim(), &out_shapes[i])); - } - TF_RETURN_IF_ERROR(c->set_output("batched_tensors", out_shapes)); - TF_RETURN_IF_ERROR(c->set_output("id", {c->Scalar()})); - TF_RETURN_IF_ERROR(c->set_output( - "batch_index", - {c->MakeShape({shape_inference::DimensionOrConstant(c->UnknownDim()), - shape_inference::DimensionOrConstant(3)})})); - return Status::OK(); - }) - .Doc(R"doc( -Batches all input tensors nondeterministically. - -When many instances of this Op are being run concurrently with the same -container/shared_name in the same device, some will output zero-shaped Tensors -and others will output Tensors of size up to max_batch_size. - -All Tensors in in_tensors are batched together (so, for example, labels and -features should be batched with a single instance of this operation. - -Each invocation of batch emits an `id` scalar which will be used to identify -this particular invocation when doing unbatch or its gradient. - -Each op which emits a non-empty batch will also emit a non-empty batch_index -Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, -start, and length of elements of each set of Tensors present in batched_tensors. - -Batched tensors are concatenated along the first dimension, and all tensors in -in_tensors must have the first dimension of the same size. - -in_tensors: The tensors to be batched. -num_batch_threads: Number of scheduling threads for processing batches of work. - Determines the number of batches processed in parallel. -max_batch_size: Batch sizes will never be bigger than this. -batch_timeout_micros: Maximum number of microseconds to wait before outputting - an incomplete batch. -allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does - nothing. Otherwise, supplies a list of batch sizes, causing the op to pad - batches up to one of those sizes. The entries must increase monotonically, and - the final entry must equal max_batch_size. -grad_timeout_micros: The timeout to use for the gradient. See Unbatch. -batched_tensors: Either empty tensors or a batch of concatenated Tensors. -batch_index: If out_tensors is non-empty, has information to invert it. -container: Controls the scope of sharing of this batch. -id: always contains a scalar with a unique ID for this invocation of Batch. -shared_name: Concurrently running instances of batch in the same device with the - same container and shared_name will batch their elements together. If left - empty, the op name will be used as the shared name. -T: the types of tensors to be batched. -)doc"); - -REGISTER_OP("Unbatch") - .Input("batched_tensor: T") - .Input("batch_index: int64") - .Input("id: int64") - .Output("unbatched_tensor: T") - .Attr("timeout_micros: int") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("T: type") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle out_shape; - TF_RETURN_IF_ERROR( - c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &out_shape)); - c->set_output(0, out_shape); - return Status::OK(); - }) - .Doc(R"doc( -Reverses the operation of Batch for a single output Tensor. - -An instance of Unbatch either receives an empty batched_tensor, in which case it -asynchronously waits until the values become available from a concurrently -running instance of Unbatch with the same container and shared_name, or receives -a non-empty batched_tensor in which case it finalizes all other concurrently -running instances and outputs its own element from the batch. - -batched_tensor: The possibly transformed output of Batch. The size of the first - dimension should remain unchanged by the transformations for the operation to - work. -batch_index: The matching batch_index obtained from Batch. -id: The id scalar emitted by Batch. -unbatched_tensor: The Tensor corresponding to this execution. -timeout_micros: Maximum amount of time (in microseconds) to wait to receive the - batched input tensor associated with a given invocation of the op. -container: Container to control resource sharing. -shared_name: Instances of Unbatch with the same container and shared_name are - assumed to possibly belong to the same batch. If left empty, the op name will - be used as the shared name. -)doc"); - -REGISTER_OP("UnbatchGrad") - .Input("original_input: T") - .Input("batch_index: int64") - .Input("grad: T") - .Input("id: int64") - .Output("batched_grad: T") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("T: type") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(2)))); - return Status::OK(); - }) - .Doc(R"doc( -Gradient of Unbatch. - -Acts like Batch but using the given batch_index index of batching things as they -become available. This ensures that the gradients are propagated back in the -same session which did the forward pass. - -original_input: The input to the Unbatch operation this is the gradient of. -batch_index: The batch_index given to the Unbatch operation this is the gradient -of. -grad: The downstream gradient. -id: The id scalar emitted by Batch. -batched_grad: The return value, either an empty tensor or the batched gradient. -container: Container to control resource sharing. -shared_name: Instances of UnbatchGrad with the same container and shared_name - are assumed to possibly belong to the same batch. If left empty, the op name - will be used as the shared name. - )doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index cee4d7b4a9710e285957f27ace7c2762c473c5c7..4e0b3f9af989c414ad88c510c1bfd180dbadd5ea 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,18 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.batching.ops import gen_batch_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.batching.ops.gen_batch_ops import * +from tensorflow.python.ops.gen_batch_ops import * # pylint: enable=wildcard-import -from tensorflow.contrib.util import loader -from tensorflow.python.framework import ops -from tensorflow.python.platform import resource_loader - - -_batch_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_batch_ops.so")) @ops.RegisterGradient("Batch") diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 7eb1e20c42283a38564f7686db0015f153f469ed..83a59695d7db7e0a24fb437a3ea71a4d9e23c93f 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h index ced27a88336324fb8c4be490138291d9234693f9..40a39a5569854350c72a47102f3dac07b362ce8e 100644 --- a/tensorflow/contrib/batching/test_util/fake_clock_env.h +++ b/tensorflow/contrib/batching/test_util/fake_clock_env.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ +#define TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ #include "tensorflow/core/kernels/batching_util/fake_clock_env.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h index fb61bc2eea2ec6eb560670148611c66ddc3d73df..aa2ed0a385125fa090a7a56b6339a87eb2d57b1f 100644 --- a/tensorflow/contrib/batching/util/periodic_function.h +++ b/tensorflow/contrib/batching/util/periodic_function.h @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#define TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ #include "tensorflow/core/kernels/batching_util/periodic_function.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 392ac7fa1ce600a64ee3b941b70b01447645e4aa..6fdcd0f996ee011842a5add79f06264a28a2145c 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -196,6 +196,7 @@ py_test( name = "quantile_ops_test", size = "small", srcs = ["python/kernel_tests/quantile_ops_test.py"], + shard_count = 3, srcs_version = "PY2AND3", deps = [ ":quantile_ops_py", diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 8600c8c53caa5fd4274ba6730fc764d8315d680c..88f30064076d1b9410665e06ca27e20d14c6dde0 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -46,6 +46,7 @@ const char* const kHandleName = "handle"; const char* const kNextStampTokenName = "next_stamp_token"; const char* const kStampTokenName = "stamp_token"; const char* const kAreBucketsReadyName = "are_buckets_ready"; +const char* const kGenerateQuantiles = "generate_quantiles"; // Names for sparse arguments. const char* const kNumSparseFeaturesName = "num_sparse_features"; const char* const kSparseBucketsName = "sparse_buckets"; @@ -182,6 +183,16 @@ std::vector GenerateBoundaries(const QuantileStream& stream, return boundaries; } +// Generates quantiles on a finalized QuantileStream. +std::vector GenerateQuantiles(const QuantileStream& stream, + int num_quantiles) { + // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values + // will be returned. + std::vector boundaries = stream.GenerateQuantiles(num_quantiles); + CHECK_EQ(boundaries.size(), num_quantiles + 1); + return boundaries; +} + // Copies quantiles to output list. void CopyBoundaries(OpKernelContext* const context, const std::vector& boundaries, const int64 index, @@ -224,6 +235,8 @@ class CreateQuantileAccumulatorOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr(kNumQuantilesName, &num_quantiles_)); OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_)); + OP_REQUIRES_OK(context, + context->GetAttr(kGenerateQuantiles, &generate_quantiles_)); } void Compute(OpKernelContext* context) override { @@ -231,9 +244,9 @@ class CreateQuantileAccumulatorOp : public OpKernel { // other exceptions. If one already exists, it unrefs the new one. const Tensor* stamp_token_t; OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); - auto result = - new QuantileStreamResource(epsilon_, num_quantiles_, max_elements_, - stamp_token_t->scalar()()); + auto result = new QuantileStreamResource(epsilon_, num_quantiles_, + max_elements_, generate_quantiles_, + stamp_token_t->scalar()()); auto status = CreateResource(context, HandleFromInput(context, 0), result); if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { OP_REQUIRES(context, false, status); @@ -246,6 +259,7 @@ class CreateQuantileAccumulatorOp : public OpKernel { // An upperbound on the number of enteries that the summaries might have // for a feature. int64 max_elements_; + bool generate_quantiles_; }; REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU), @@ -597,10 +611,15 @@ class QuantileAccumulatorFlushOp : public OpKernel { << "Passed stamp token: " << stamp_token << " " << "Current token: " << streams_resource->stamp(); QuantileStream* stream = streams_resource->stream(stamp_token); + bool generate_quantiles = streams_resource->generate_quantiles(); stream->Finalize(); + streams_resource->set_boundaries( stamp_token, - GenerateBoundaries(*stream, streams_resource->num_quantiles())); + generate_quantiles + ? GenerateQuantiles(*stream, streams_resource->num_quantiles()) + : GenerateBoundaries(*stream, streams_resource->num_quantiles())); + streams_resource->Reset(next_stamp_token); } }; diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h index e1bef0278846e7ff6abc91e8c57f780af45e8b41..3c54868951a6db93a8b685c8da4dfc78996b7b1f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ #include "tensorflow/core/lib/hash/hash.h" @@ -58,4 +58,4 @@ struct ClassPartitionKey { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h index 3814edb5675be74794a08e00becb649f8fc53fdb..ec4e7c52bb5f4536a50192e1b5fcc019dd7b2511 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ #include #include @@ -79,4 +79,4 @@ class FeatureStatsAccumulator { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h b/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h index aed0d9fdac108dff4576cc1563dae420340387be..37a71037041445e6a6fcf6290015b93cffef1618 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ #include #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" @@ -50,4 +50,4 @@ class ExamplePartitioner { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h index 339c2e0fded10e6a7b140da62e152e2868ffd164..382b85cf0b2c146f82fa79551c569b9c70d9b7a6 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h @@ -13,8 +13,8 @@ // limitations under the License. // // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" @@ -58,4 +58,4 @@ struct FeatureSplitCandidate { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h index 34e3ddb777242553d62035a51f1aec33d0f9ba54..3dd03215d88abc223a2d081d11901ffd3fb7aaa9 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ #include @@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index 642a183aec5c7e591579fa5ee91d45729bfb624d..cd925f6b65e569538212e9c26aef0abc8482960b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Eigenvalues" @@ -298,4 +298,4 @@ struct NodeStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h index 054ccd9a8cd0be0c48b14cca013f15677deba900..81ee2774bdab91f492064455055181c56ef6a065 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ #include @@ -81,4 +81,4 @@ struct SplitStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index ee29a8aa797b96d41ec2d77bf831ee287d5443e7..cc3dc226cdbc88fc7010ada1e7f0e6c0a3913c5f 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ #include @@ -45,4 +45,4 @@ class MultipleAdditiveTrees { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h index 70037d5bd8f446bdbbfcc468edb8a76c05e4fab7..804b218f1c08338df80f8dd2e6135f5d92b9928e 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ #include #include @@ -129,4 +129,4 @@ constexpr decltype(CompareFn()) } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index fd577ad712f228fa8016a48942511a3263aae5da..1c4181f1b13b01f85833157e554c3b821f96ff90 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ #include #include @@ -322,4 +322,4 @@ WeightedQuantilesStream::GetQuantileSpecs( } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index c329c6d4f7363a7738b06648943fe1dbd065cce5..aec232f3cbb096f0aa51e4362a821882391f8027 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ #include #include @@ -334,4 +334,4 @@ constexpr decltype(CompareFn()) } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h index d95878ec87b9e903930d2016bb573eee2573f776..b98190b10dc88d5bba9023e771844a2bd6c9a45d 100644 --- a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h +++ b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" #include "tensorflow/core/framework/tensor.h" @@ -42,4 +42,4 @@ void RandomlyInitializeBatchFeatures( } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h index 5e12429ba778344edda623d149e017661f1e0222..1838b4cee21afb5df72a9b902f0ec0ce6f7ac627 100644 --- a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h +++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ #include @@ -72,4 +72,4 @@ class RandomTreeGen { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h index 604ff02744b25b136bd935bf85635731730effe8..43526c229a65d45a2b0ced4aa1262d489526fc7b 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ #include "tensorflow/contrib/boosted_trees/lib/utils/example.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT @@ -46,4 +46,4 @@ class DecisionTree { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index badc629a118f768d5aa25ef1b94b8190e6910c7f..da5e7448519cb7f4092f7bbbe1b526271008ec22 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ #include #include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h" @@ -92,4 +92,4 @@ class BatchFeatures { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h index c3f1c918ca5f603cf9470071017d8ee384dc9320..928bfbfe5c9394ab4083aabced4c8e1149bb10aa 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ #include #include @@ -74,4 +74,4 @@ class DropoutUtils { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h index 54f60e1dee49a4a40b84fcc6e042fac1858aa187..1371ff337f78dd1c38f2bd0ba86911642f3aeb3e 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/example.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ #include #include @@ -131,4 +131,4 @@ struct Example { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h index 5b33c8158879ec65425ac77b5338ee98fbdf07db..1b654e1c44e545fb97216ad950f3cd2d3240ffd0 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ #include @@ -205,4 +205,4 @@ class ExamplesIterable { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/macros.h b/tensorflow/contrib/boosted_trees/lib/utils/macros.h index 28ea0a4dc191af66ced574d78d9873cc8335f491..9a53fb2ef7d0581986885f3bc8233d91b67c0166 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/macros.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/macros.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ #include "tensorflow/core/platform/macros.h" @@ -23,4 +23,4 @@ return (STATUS); \ } -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h index c141fe059d48072c6c4495535eafec9633616d21..b2166f53d7a037fb8ec53d5295b98bb82b17d4c7 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ #include "tensorflow/core/platform/logging.h" @@ -44,4 +44,4 @@ class OptionalValue { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h index c80431b5587cecc0bce22f6150a69d30397529da..ec06787e1db69514c9e60f6d152f3b0c7de23842 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ +#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ +#define TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ #include "tensorflow/core/lib/core/threadpool.h" @@ -30,4 +30,4 @@ void ParallelFor(int64 batch_size, int64 desired_parallelism, } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ +#endif // TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/random.h b/tensorflow/contrib/boosted_trees/lib/utils/random.h index 6dd55fcacc42b88116737ab6fb413852ffc1473d..546d344f5585458f10699a644621f0adf26b6446 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/random.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/random.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ +#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ +#define TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ #include "tensorflow/core/lib/random/simple_philox.h" @@ -36,4 +36,4 @@ inline int32 PoissonBootstrap(random::SimplePhilox* rng) { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ +#endif // TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h index 9664c9d1c6a0c0c8b1bbd1506944c54d2310c611..87fb1fbf5ae3cc6bcf25f68a180d1d9b21ef4d6f 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" @@ -127,4 +127,4 @@ class SparseColumnIterable { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h index 58f5e5a0d18788375cd8166d1fcbdc7c294ba5e2..475d3718eccc2b23260b7cf5286abdd31ef1bad6 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -57,4 +57,4 @@ class TensorUtils { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index 1fa70bafddb0c94f47d006d5694bea941edaddf9..bb57dcf8ae7475486bcc0fc82460cbbce9a18b68 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -39,6 +39,7 @@ REGISTER_OP("CreateQuantileAccumulator") .Attr("max_elements: int = 1099511627776") // 1 << 40 .Attr("epsilon: float") .Attr("num_quantiles: int") + .Attr("generate_quantiles: bool=False") .Input("quantile_accumulator_handle: resource") .Input("stamp_token: int64") .SetShapeFn([](shape_inference::InferenceContext* c) { diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 888d5c57ed33446c8b6f18d2d1e393647613d132..eefa7ef0dccf5e88099974302dd26eebe21b1bd2 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -106,9 +106,11 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): | 6 | 16 | [16, 17, 18, 19, 20, 21] """ + num_quantiles = 3 with self.test_session() as sess: accumulator = quantile_ops.QuantileAccumulator( - init_stamp_token=0, num_quantiles=3, epsilon=0.001, name="q1") + init_stamp_token=0, num_quantiles=num_quantiles, + epsilon=0.001, name="q1") resources.initialize_resources(resources.shared_resources()).run() input_column = array_ops.placeholder(dtypes.float32) weights = array_ops.placeholder(dtypes.float32) @@ -131,8 +133,104 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): buckets, are_ready_flush = (sess.run( [buckets, are_ready_flush])) self.assertEqual(True, are_ready_flush) + self.assertEqual(num_quantiles + 1, len(buckets)) self.assertAllEqual([1, 86., 170., 253.], buckets) + def testStreamingQuantileBucketsLowPrecisionInput(self): + """Tests inputs that simulate low precision float16 values.""" + + num_quantiles = 3 + # set generate_quantiles to True since the test will generate fewer + # boundaries otherwise. + with self.test_session() as sess: + accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token=0, num_quantiles=num_quantiles, + epsilon=0.001, name="q1", generate_quantiles=True) + resources.initialize_resources(resources.shared_resources()).run() + input_column = array_ops.placeholder(dtypes.float32) + weights = array_ops.placeholder(dtypes.float32) + update = accumulator.add_summary( + stamp_token=0, + column=input_column, + example_weights=weights) + + with self.test_session() as sess: + # This input is generated by integer in the range [2030, 2060] + # but represented by with float16 precision. Integers <= 2048 are + # exactly represented, whereas numbers > 2048 are rounded; and hence + # numbers > 2048 are repeated. For precision loss / rounding, see: + # https://en.wikipedia.org/wiki/Half-precision_floating-point_format. + # + # The intent of the test is not handling of float16 values, but to + # validate the number of buckets is returned, in cases where the input + # may contain repeated values. + inputs = [ + 2030.0, 2031.0, 2032.0, 2033.0, 2034.0, 2035.0, 2036.0, 2037.0, + 2038.0, 2039.0, 2040.0, 2041.0, 2042.0, 2043.0, 2044.0, 2045.0, + 2046.0, 2047.0, 2048.0, 2048.0, 2050.0, 2052.0, 2052.0, 2052.0, + 2054.0, 2056.0, 2056.0, 2056.0, 2058.0, 2060.0 + ] + sess.run(update, + {input_column: inputs, + weights: [1] * len(inputs)}) + + with self.test_session() as sess: + sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1)) + are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1)) + buckets, are_ready_flush = (sess.run( + [buckets, are_ready_flush])) + self.assertEqual(True, are_ready_flush) + self.assertEqual(num_quantiles + 1, len(buckets)) + self.assertAllEqual([2030, 2040, 2050, 2060], buckets) + + def _testStreamingQuantileBucketsHelper(self, inputs): + """Helper to test quantile buckets on different inputs.""" + + # Use 3 quantiles, 4 boundaries for simplicity. + num_quantiles = 3 + # set generate_quantiles to True since the test will generate fewer + # boundaries otherwise. + with self.test_session() as sess: + accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token=0, num_quantiles=num_quantiles, + epsilon=0.001, name="q1", generate_quantiles=True) + resources.initialize_resources(resources.shared_resources()).run() + input_column = array_ops.placeholder(dtypes.float32) + weights = array_ops.placeholder(dtypes.float32) + update = accumulator.add_summary( + stamp_token=0, + column=input_column, + example_weights=weights) + + with self.test_session() as sess: + sess.run(update, + {input_column: inputs, + weights: [1] * len(inputs)}) + + with self.test_session() as sess: + sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1)) + are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1)) + buckets, are_ready_flush = (sess.run( + [buckets, are_ready_flush])) + self.assertEqual(True, are_ready_flush) + self.assertEqual(num_quantiles + 1, len(buckets)) + + def testStreamingQuantileBucketsRepeatedSingleValue(self): + inputs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + self._testStreamingQuantileBucketsHelper(inputs) + + def testStreamingQ2antileBucketsRepeatedTwoValues(self): + inputs = [1, 1, 1, 2, 2, 2, 2, 2, 1, 1] + self._testStreamingQuantileBucketsHelper(inputs) + + def testStreamingQ2antileBucketsRepeatedTwoValuesUnbalanced(self): + inputs = [7, 7, 7, 2, 7, 7, 2, 2, 7, 7] + self._testStreamingQuantileBucketsHelper(inputs) + + def testStreamingQuantileBucketsFewerInputstThanBuckets(self): + inputs = [5] + self._testStreamingQuantileBucketsHelper(inputs) + def testStreamingQuantileBuckets(self): """Sets up the quantile summary op test as follows. diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 23168bf4935e92bcb5072348361ae04861641b6d..b281a4c6d1cab9bfa1dc4018c8f49a16f21f2a36 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -104,7 +104,7 @@ def run_handler_scheduled_ops(per_handler_ops, stamp, worker_device): batched_ops = collections.defaultdict(list) # Group the ops by their batching_key. Ops that share the same batching key # can be executed together. - for handler in per_handler_ops.keys(): + for handler in sorted(per_handler_ops.keys()): for op in per_handler_ops[handler]: batched_ops[(op.batching_key(), op.batch_runner_fn())].append(op) op_results = {} diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 294e04002adac62fc123a3242a05a1b36f422433..97d57e8b23608d4c3a8719426a75056fc6417d1d 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -47,7 +47,8 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): num_quantiles, max_elements=None, name=None, - container=None): + container=None, + generate_quantiles=False): """Creates a QuantileAccumulator object. Args: @@ -57,8 +58,11 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): max_elements: Maximum number of elements added to the accumulator. name: the name to save the accumulator under. container: An optional `string`. Defaults to `""` + generate_quantiles: Generate quantiles instead of approximate boundaries. + If true, exactly `num_quantiles` will be produced in the final summary. """ self._epsilon = epsilon + self._generate_quantiles = generate_quantiles name = _PATTERN.sub("", name) with ops.name_scope(name, "QuantileAccumulator") as name: @@ -70,7 +74,8 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token, epsilon=epsilon, max_elements=max_elements, - num_quantiles=num_quantiles) + num_quantiles=num_quantiles, + generate_quantiles=generate_quantiles) is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( self._quantile_accumulator_handle) resources.register_resource(self._quantile_accumulator_handle, @@ -176,7 +181,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): summaries=summary) def flush(self, stamp_token, next_stamp_token): - """Finalizes quantile summary stream and resets it for next iteration.""" + """Finalizes quantile summary stream and resets it for next iteration. + + Args: + stamp_token: Exepcted current token. + next_stamp_token: Next value for the token. + Returns: + A list of quantiles or approximate boundaries. + """ return gen_quantile_ops.quantile_accumulator_flush( quantile_accumulator_handle=self._quantile_accumulator_handle, stamp_token=stamp_token, diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index ad9c8961aaadbc4c1ff6bdc7793171d0ad48d75f..3ebf28ea442edf87815c39971ae9e01a2a8aae9a 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h" @@ -179,4 +179,4 @@ class DecisionTreeEnsembleResource : public StampedResource { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h index fb29f79e578e8e52b67de631c527be35b7772b41..fdaaae7f472c8f564ab45a8366d3746cbf1158ee 100644 --- a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h" #include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h" // NOLINT @@ -32,12 +32,14 @@ using QuantileStream = class QuantileStreamResource : public StampedResource { public: QuantileStreamResource(const float epsilon, const int32 num_quantiles, - const int64 max_elements, int64 stamp_token) + const int64 max_elements, bool generate_quantiles, + int64 stamp_token) : stream_(epsilon, max_elements), are_buckets_ready_(false), epsilon_(epsilon), num_quantiles_(num_quantiles), - max_elements_(max_elements) { + max_elements_(max_elements), + generate_quantiles_(generate_quantiles) { set_stamp(stamp_token); } @@ -74,6 +76,11 @@ class QuantileStreamResource : public StampedResource { are_buckets_ready_ = are_buckets_ready; } + bool generate_quantiles() const { return generate_quantiles_; } + void set_generate_quantiles(bool generate_quantiles) { + generate_quantiles_ = generate_quantiles; + } + private: ~QuantileStreamResource() override {} @@ -95,10 +102,15 @@ class QuantileStreamResource : public StampedResource { const int32 num_quantiles_; // An upper-bound for the number of elements. int64 max_elements_; + + // Generate quantiles instead of approximate boundaries. + // If true, exactly `num_quantiles` will be produced in the final summary. + bool generate_quantiles_; + TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource); }; } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ diff --git a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h index aabeeb98516eda6f7e8e7e296d6860fe5d8d5ec3..957bbe8d61d3dd32adba1a7f0cf840c69bce6273 100644 --- a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/platform/mutex.h" @@ -39,4 +39,4 @@ class StampedResource : public ResourceBase { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h index 7d0eee59ae2f47503c4f8994ef356ce0dc336733..b349063715c903c982cfe2fb116b6525e35ff63b 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ #include #include @@ -205,4 +205,4 @@ class BigQueryTableAccessor { }; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h index b2b11f4f57800d55ebc86273fcda71e673ff143a..59f23332983e2328286d3b1b8b8c8fa228be991e 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ #include @@ -401,4 +401,4 @@ const string kTestEmptyRow = R"({ } // namespace } // namepsace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index 013b3a862f13fd9017fade500d391ecc2bd27fae..fd57734298affda13fa90f4cff560eeeb08e59ab 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -47,4 +47,4 @@ ExternalProject_Add(snappy ) # actually enables snappy in the source code -add_definitions(-DTF_USE_SNAPPY) \ No newline at end of file +add_definitions(-DTF_USE_SNAPPY) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index e37d059a84cb3d75cebf2473e7880f6d6cb20a69..7db454bd83ec7fee463b8cd448f5a5ff4ba73258 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -1,3 +1,5 @@ +# python_sanity_test.py will complain about invalid or missing entries +# problematic entries can be commented for temporary whitelisting tensorflow tensorflow/core tensorflow/core/example @@ -109,7 +111,6 @@ tensorflow/contrib/android/java/org/tensorflow/contrib tensorflow/contrib/android/java/org/tensorflow/contrib/android tensorflow/contrib/android/jni tensorflow/contrib/batching -tensorflow/contrib/batching/kernels tensorflow/contrib/batching/python tensorflow/contrib/batching/python/ops tensorflow/contrib/bayesflow @@ -308,6 +309,8 @@ tensorflow/contrib/metrics tensorflow/contrib/metrics/python tensorflow/contrib/metrics/python/metrics tensorflow/contrib/metrics/python/ops +tensorflow/contrib/mpi_collectives/python +tensorflow/contrib/mpi_collectives/python/ops tensorflow/contrib/model_pruning tensorflow/contrib/model_pruning/examples tensorflow/contrib/model_pruning/examples/cifar10 diff --git a/tensorflow/contrib/cmake/python_sanity_test.py b/tensorflow/contrib/cmake/python_sanity_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e0056823a80833329bcb1f275a3384a33127bb40 --- /dev/null +++ b/tensorflow/contrib/cmake/python_sanity_test.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Complain about invalid or missing entries in python_*.txt files. + +Problematic entries can be commented for temporary whitelisting. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest + + +def abs_path(path): + root = os.path.dirname(__file__) + + for _ in range(3): + root = os.path.join(root, os.pardir) + + path = os.path.join(root, path) + path = os.path.abspath(path) + return path + + +def read_entries(test): + with open(abs_path(test.entries_file), "r") as f: + lines = f.readlines() + + lines = [line.strip() for line in lines] + lines = [line for line in lines if line] + + test.entries = [] + test.whitelist = [] + + for line in lines: + # line is comment + if line.startswith("#"): + line = line[1:].strip() + # whitelist entry + if line.startswith("tensorflow/"): + test.whitelist.append(line) + # line has comment -> strip comment + elif line.find("#") != -1: + line = line[:line.find("#")].strip() + test.entries.append(line) + else: + test.entries.append(line) + + +def test_invalid_directories(test): + for entry in test.entries: + if not os.path.isdir(abs_path(entry)): + problem = "'" + test.entries_file + "' contains invalid '" + entry + "'" + solution = ("Please remove the invalid entry (or add the missing " + "directory).") + raise AssertionError(problem + "\n" + solution) + + +def test_missing_directory(test, path): + if path in test.whitelist: + return + + dir_exists = os.path.isdir(abs_path(path)) + entry_exists = path in test.entries + + if dir_exists and not entry_exists: + problem = "'" + test.entries_file + "' is missing '" + path + "'" + solution = "Please add the missing entry (comment to whitelist if needed)." + raise AssertionError(problem + "\n" + solution) + + +class PythonModuleTest(unittest.TestCase): + + def setUp(self): + self.entries_file = "tensorflow/contrib/cmake/python_modules.txt" + read_entries(self) + + def testInvalidEntries(self): + test_invalid_directories(self) + + def testMissingModules(self): + module_names = next(os.walk(abs_path("tensorflow/contrib")))[1] + + for module_name in module_names: + path = "tensorflow/contrib/" + module_name + + test_missing_directory(self, path + "/python") + test_missing_directory(self, path + "/python/ops") + test_missing_directory(self, path + "/python/kernels") + test_missing_directory(self, path + "/python/layers") + + +class PythonProtoTest(unittest.TestCase): + + def setUp(self): + self.entries_file = "tensorflow/contrib/cmake/python_protos.txt" + read_entries(self) + + def testInvalidEntries(self): + test_invalid_directories(self) + + +class PythonProtoCCTest(unittest.TestCase): + + def setUp(self): + self.entries_file = "tensorflow/contrib/cmake/python_protos_cc.txt" + read_entries(self) + + def testInvalidEntries(self): + test_invalid_directories(self) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 24d7fb82a268623be06c2b98b5857b6b9b95c3a1..129c208ecd6b574ed63c2fe378e1a6ebb92de558 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -126,7 +126,9 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_protos_cc_srcs} ) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 6f56e9d0869bc0d3311ffbc68326f8ab43758019..138993db35252d3f1ab6326dff463bdc10cabdb1 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -15,6 +15,7 @@ set(tf_op_lib_names "audio_ops" "array_ops" + "batch_ops" "bitwise_ops" "candidate_sampling_ops" "checkpoint_ops" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 17bbdb1a86f4a1b026b6d159a7b8adad9a3d1f57..8862390d2b62f72c11d60f2ae48a845d22363f06 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -126,7 +126,8 @@ STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}") STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}") foreach(python_proto ${python_protos}) - if(NOT python_proto MATCHES "\#") + if(NOT python_proto MATCHES "^\#") + STRING(REGEX REPLACE " *\#.*" "" python_proto "${python_proto}") if(NOT EXISTS "${tensorflow_source_dir}/${python_proto}") message(SEND_ERROR "Python proto directory not found: ${python_proto}") endif() @@ -147,7 +148,8 @@ STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}") STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}") foreach(python_proto_cc ${python_protos_cc}) - if(NOT python_proto_cc MATCHES "\#") + if(NOT python_proto_cc MATCHES "^\#") + STRING(REGEX REPLACE " *\#.*" "" python_proto_cc "${python_proto_cc}") if(NOT EXISTS "${tensorflow_source_dir}/${python_proto_cc}") message(SEND_ERROR "Python proto CC directory not found: ${python_proto_cc}") endif() @@ -209,7 +211,8 @@ STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}") STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}") foreach(python_module ${python_modules}) - if(NOT python_module MATCHES "\#") + if(NOT python_module MATCHES "^\#") + STRING(REGEX REPLACE " *\#.*" "" python_module "${python_module}") if(NOT EXISTS "${tensorflow_source_dir}/${python_module}") message(SEND_ERROR "Python module not found: ${python_module}") endif() @@ -314,6 +317,7 @@ endfunction() GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") +GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("math_ops") GENERATE_PYTHON_OP_LIB("functional_ops") diff --git a/tensorflow/contrib/coder/kernels/range_coder.h b/tensorflow/contrib/coder/kernels/range_coder.h index c24fb707fc9f1776a4e6e7be7df3245c0cdccb0b..f46413072e34a55128d7854b9c312dfdde457d85 100644 --- a/tensorflow/contrib/coder/kernels/range_coder.h +++ b/tensorflow/contrib/coder/kernels/range_coder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ +#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ +#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ #include #include @@ -106,4 +106,4 @@ class RangeDecoder { }; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ +#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h index 95241a8682891dc94780a9194d20aa9dc22e17c8..b8aabcef62e9de53810397960f871abc4adc0cf9 100644 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ +#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ #include @@ -30,4 +30,4 @@ Status MergeAxes(const TensorShape& broadcast_shape, std::vector* merged_storage_shape_pointer); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1fbf18f30a293de697826885d15bb95b40568daa..1cf0202fd88951ffcc611af39fa0915110c4d819 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -475,7 +475,7 @@ py_test( py_test( name = "stats_dataset_ops_test", - size = "small", + size = "medium", srcs = ["stats_dataset_ops_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 69252612a8e6cb29c513003188946be21f3432c2..dd8247bfd47a9880c7cfe905103702e43b1f2165 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -765,6 +765,15 @@ class MapDatasetSerializationTest( self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) + + self.run_core_tests(_build_ds, None, 10) + def testCaptureDefunInMapFn(self): num_outputs = 100 @@ -856,6 +865,15 @@ class ParallelMapDatasetSerializationTest( self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) + + self.run_core_tests(_build_ds, None, 10) + def testCaptureDefunInMapFn(self): num_outputs = 100 diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 95848af69950bdaa680c41daecd8cbd8f3174f8e..7f510c42215f48a9e795eb81bd9f66b0a2108335 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -128,6 +128,19 @@ cuda_py_test( tags = ["no_pip"], ) +cuda_py_test( + name = "autoregressive_test", + size = "small", + srcs = ["python/kernel_tests/autoregressive_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "binomial_test", size = "small", @@ -918,6 +931,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "real_nvp_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/real_nvp_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "permute_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 7b401e178f35fe56e4eb461936565f5c630ec4cf..60a187e541df4a794ae3944c30c427944915f7d0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -23,6 +23,7 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops.autoregressive import * from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.cauchy import * from tensorflow.contrib.distributions.python.ops.chi2 import * @@ -84,6 +85,7 @@ from tensorflow.python.ops.distributions.uniform import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + 'auto_correlation', 'bijectors', 'Cauchy', 'ConditionalDistribution', @@ -92,6 +94,7 @@ _allowed_symbols = [ 'NOT_REPARAMETERIZED', 'ReparameterizationType', 'Distribution', + 'Autoregressive', 'Binomial', 'Bernoulli', 'BernoulliWithSigmoidProbs', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0928dc3f358ede693865a8d1ff9257a0ecbe9499 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py @@ -0,0 +1,94 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import autoregressive as autoregressive_lib +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.platform import test + + +class AutogressiveTest(test_util.VectorDistributionTestHelpers, test.TestCase): + """Tests the Autoregressive distribution.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def _random_scale_tril(self, event_size): + n = np.int32(event_size * (event_size + 1) // 2) + p = 2. * self._rng.random_sample(n).astype(np.float32) - 1. + return distribution_util.fill_triangular(0.25 * p) + + def _normal_fn(self, affine_bijector): + def _fn(samples): + scale = math_ops.exp(affine_bijector.forward(samples)) + return independent_lib.Independent( + normal_lib.Normal(loc=0., scale=scale, validate_args=True), + reinterpreted_batch_ndims=1) + return _fn + + def testSampleAndLogProbConsistency(self): + batch_shape = [] + event_size = 2 + with self.test_session() as sess: + batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0) + sample0 = array_ops.zeros(batch_event_shape) + affine = Affine(scale_tril=self._random_scale_tril(event_size)) + ar = autoregressive_lib.Autoregressive( + self._normal_fn(affine), sample0, validate_args=True) + self.run_test_sample_consistent_log_prob( + sess.run, ar, radius=1., center=0., rtol=0.01) + + def testCompareToBijector(self): + """Demonstrates equivalence between TD, Bijector approach and AR dist.""" + sample_shape = np.int32([4, 5]) + batch_shape = np.int32([]) + event_size = np.int32(2) + with self.test_session() as sess: + batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0) + sample0 = array_ops.zeros(batch_event_shape) + affine = Affine(scale_tril=self._random_scale_tril(event_size)) + ar = autoregressive_lib.Autoregressive( + self._normal_fn(affine), sample0, validate_args=True) + ar_flow = MaskedAutoregressiveFlow( + is_constant_jacobian=True, + shift_and_log_scale_fn=lambda x: [None, affine.forward(x)], + validate_args=True) + td = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=ar_flow, + event_shape=[event_size], + batch_shape=batch_shape, + validate_args=True) + x_shape = np.concatenate( + [sample_shape, batch_shape, [event_size]], axis=0) + x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1. + td_log_prob_, ar_log_prob_ = sess.run([td.log_prob(x), ar.log_prob(x)]) + self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..46fe7797419a9906ecdad60dd0dfe1e9d7c743ed --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py @@ -0,0 +1,144 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MaskedAutoregressiveFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import real_nvp_default_template +from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import RealNVP +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.platform import test + + +class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): + + @property + def _real_nvp_kwargs(self): + return { + "shift_and_log_scale_fn": real_nvp_default_template( + hidden_layers=[3], shift_only=False), + "is_constant_jacobian": False, + } + + def testBijector(self): + x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4 * 2) + with self.test_session() as sess: + nvp = RealNVP( + num_masked=4, + validate_args=True, + **self._real_nvp_kwargs) + x = constant_op.constant(x_) + forward_x = nvp.forward(x) + # Use identity to invalidate cache. + inverse_y = nvp.inverse(array_ops.identity(forward_x)) + fldj = nvp.forward_log_det_jacobian(x) + # Use identity to invalidate cache. + ildj = nvp.inverse_log_det_jacobian(array_ops.identity(forward_x)) + variables.global_variables_initializer().run() + [ + forward_x_, + inverse_y_, + ildj_, + fldj_, + ] = sess.run([ + forward_x, + inverse_y, + ildj, + fldj, + ]) + self.assertEqual("real_nvp", nvp.name) + self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.) + self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.) + self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.) + + def testMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + nvp = RealNVP( + num_masked=3, + validate_args=True, + **self._real_nvp_kwargs) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=nvp, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + def testInvertMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + nvp = Invert(RealNVP( + num_masked=3, + validate_args=True, + **self._real_nvp_kwargs)) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=nvp, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + +class NICETest(RealNVPTest): + + @property + def _real_nvp_kwargs(self): + return { + "shift_and_log_scale_fn": real_nvp_default_template( + hidden_layers=[2], shift_only=True), + "is_constant_jacobian": True, + } + + +class RealNVPConstantShiftScaleTest(RealNVPTest): + + @property + def _real_nvp_kwargs(self): + + def constant_shift_log_scale_fn(x0, output_units): + del x0, output_units + shift = constant_op.constant([0.1]) + log_scale = constant_op.constant([0.5]) + return shift, log_scale + + return { + "shift_and_log_scale_fn": constant_shift_log_scale_fn, + "is_constant_jacobian": True, + } + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 49451446b56d290f130c5db90c13b94974d92dc9..e216d88cb190dc16fc0056186f80817d6f2d7c67 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,12 +22,15 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test +@test_util.with_c_api class _ReshapeBijectorTest(object): """Base class for testing the reshape transformation. @@ -136,7 +139,8 @@ class _ReshapeBijectorTest(object): sess.run(bijector.forward_event_shape_tensor(shape_in), feed_dict=feed_dict) - def testInvalidDimensionsOpError(self): + # pylint: disable=invalid-name + def _testInvalidDimensionsOpError(self, expected_error_message): with self.test_session() as sess: @@ -146,10 +150,10 @@ class _ReshapeBijectorTest(object): event_shape_in=shape_in, validate_args=True) - with self.assertRaisesError( - "elements must be either positive integers or `-1`."): + with self.assertRaisesError(expected_error_message): sess.run(bijector.forward_event_shape_tensor(shape_in), feed_dict=feed_dict) + # pylint: enable=invalid-name def testValidButNonMatchingInputOpError(self): x = np.random.randn(4, 3, 2) @@ -184,7 +188,8 @@ class _ReshapeBijectorTest(object): sess.run(bijector.forward(x), feed_dict=feed_dict) - def testInputOutputMismatchOpError(self): + # pylint: disable=invalid-name + def _testInputOutputMismatchOpError(self, expected_error_message): x1 = np.random.randn(4, 2, 3) x2 = np.random.randn(4, 1, 1, 5) @@ -196,13 +201,11 @@ class _ReshapeBijectorTest(object): event_shape_in=shape_in, validate_args=True) - # test that *all* methods check basic assertions - with self.assertRaisesError( - "Input to reshape is a tensor with"): + with self.assertRaisesError(expected_error_message): sess.run(bijector.forward(x1), feed_dict=fd_mismatched) - with self.assertRaisesError( - "Input to reshape is a tensor with"): + with self.assertRaisesError(expected_error_message): sess.run(bijector.inverse(x2), feed_dict=fd_mismatched) + # pylint: enable=invalid-name def testOneShapePartiallySpecified(self): expected_x = np.random.randn(4, 6) @@ -262,6 +265,7 @@ class _ReshapeBijectorTest(object): raise NotImplementedError("Subclass failed to implement `build_shapes`.") +@test_util.with_c_api class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -299,7 +303,22 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): validate_args=True) assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + def testInvalidDimensionsOpError(self): + if ops._USE_C_API: + error_message = "Invalid value in tensor used for shape: -2" + else: + error_message = "elements must be either positive integers or `-1`." + self._testInvalidDimensionsOpError(error_message) + + def testInputOutputMismatchOpError(self): + if ops._USE_C_API: + error_message = "Cannot reshape a tensor with" + else: + error_message = "Input to reshape is a tensor with" + self._testInputOutputMismatchOpError(error_message) + +@test_util.with_c_api class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -313,7 +332,15 @@ class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): def assertRaisesError(self, msg): return self.assertRaisesOpError(msg) + def testInvalidDimensionsOpError(self): + self._testInvalidDimensionsOpError( + "elements must be either positive integers or `-1`.") + + def testInputOutputMismatchOpError(self): + self._testInputOutputMismatchOpError("Input to reshape is a tensor with") + +@test_util.with_c_api class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -325,6 +352,13 @@ class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): def assertRaisesError(self, msg): return self.assertRaisesOpError(msg) + def testInvalidDimensionsOpError(self): + self._testInvalidDimensionsOpError( + "elements must be either positive integers or `-1`.") + + def testInputOutputMismatchOpError(self): + self._testInputOutputMismatchOpError("Input to reshape is a tensor with") + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index d292b04665e34196670ee4f1c1655f805e04e06a..04f047aa0c81b3f59b97f14554fb59cb1b3dd8af 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -27,6 +27,8 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.platform import test +rng = np.random.RandomState(0) + class VectorDiffeomixtureTest( test_util.VectorDistributionTestHelpers, test.TestCase): @@ -37,7 +39,7 @@ class VectorDiffeomixtureTest( dims = 4 vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], - mix_scale=[1.], + temperature=[1.], distribution=normal_lib.Normal(0., 1.), loc=[ None, @@ -66,7 +68,7 @@ class VectorDiffeomixtureTest( dims = 4 vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], - mix_scale=[1.], + temperature=[1.], distribution=normal_lib.Normal(1., 1.5), loc=[ None, @@ -95,7 +97,7 @@ class VectorDiffeomixtureTest( dims = 4 vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], - mix_scale=[1.], + temperature=[1.], distribution=normal_lib.Normal(0., 1.), loc=[ None, @@ -122,12 +124,39 @@ class VectorDiffeomixtureTest( self.run_test_sample_consistent_log_prob( sess.run, vdm, radius=4., center=2., rtol=0.01) + def testSampleProbConsistentBroadcastMixTwoBatchDims(self): + dims = 4 + loc_1 = rng.randn(2, 3, dims).astype(np.float32) + + with self.test_session() as sess: + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=(rng.rand(2, 3, 1) - 0.5).astype(np.float32), + temperature=[1.], + distribution=normal_lib.Normal(0., 1.), + loc=[ + None, + loc_1, + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=[np.float32(1.1)], + is_positive_definite=True), + ] * 2, + validate_args=True) + # Ball centered at component0's mean. + self.run_test_sample_consistent_log_prob( + sess.run, vdm, radius=2., center=0., rtol=0.01) + # Larger ball centered at component1's mean. + self.run_test_sample_consistent_log_prob( + sess.run, vdm, radius=3., center=loc_1, rtol=0.02) + def testMeanCovarianceNoBatch(self): with self.test_session() as sess: dims = 3 vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], - mix_scale=[10.], + temperature=[1 / 10.], distribution=normal_lib.Normal(0., 1.), loc=[ np.float32([-2.]), @@ -147,12 +176,94 @@ class VectorDiffeomixtureTest( self.run_test_sample_consistent_mean_covariance( sess.run, vdm, rtol=0.02, cov_rtol=0.08) + def testTemperatureControlsHowMuchThisLooksLikeDiscreteMixture(self): + # As temperature decreases, this should approach a mixture of normals, with + # components at -2, 2. + with self.test_session() as sess: + dims = 1 + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=[0.], + temperature=[[2.], [1.], [0.2]], + distribution=normal_lib.Normal(0., 1.), + loc=[ + np.float32([-2.]), + np.float32([2.]), + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=np.float32(0.5), + is_positive_definite=True), + ] * 2, # Use the same scale for each component. + quadrature_size=8, + validate_args=True) + + samps = vdm.sample(10000) + self.assertAllEqual((10000, 3, 1), samps.shape) + samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape. + + # One characteristic of a discrete mixture (as opposed to a "smear") is + # that more weight is put near the component centers at -2, 2, and thus + # less weight is put near the origin. + prob_of_being_near_origin = (np.abs(samps_) < 1).mean(axis=0) + self.assertGreater( + prob_of_being_near_origin[0], prob_of_being_near_origin[1]) + self.assertGreater( + prob_of_being_near_origin[1], prob_of_being_near_origin[2]) + + # Run this test as well, just because we can. + self.run_test_sample_consistent_mean_covariance( + sess.run, vdm, rtol=0.02, cov_rtol=0.08) + + def testConcentrationLocControlsHowMuchWeightIsOnEachComponent(self): + with self.test_session() as sess: + dims = 1 + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=[[-1.], [0.], [1.]], + temperature=[0.5], + distribution=normal_lib.Normal(0., 1.), + loc=[ + np.float32([-2.]), + np.float32([2.]), + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=np.float32(0.5), + is_positive_definite=True), + ] * 2, # Use the same scale for each component. + quadrature_size=8, + validate_args=True) + + samps = vdm.sample(10000) + self.assertAllEqual((10000, 3, 1), samps.shape) + samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape. + + # One characteristic of putting more weight on a component is that the + # mean is closer to that component's mean. + # Get the mean for each batch member, the names signify the value of + # concentration for that batch member. + mean_neg1, mean_0, mean_1 = samps_.mean(axis=0) + + # Since concentration is the concentration for component 0, + # concentration = -1 ==> more weight on component 1, which has mean = 2 + # concentration = 0 ==> equal weight + # concentration = 1 ==> more weight on component 0, which has mean = -2 + self.assertLess(-2, mean_1) + self.assertLess(mean_1, mean_0) + self.assertLess(mean_0, mean_neg1) + self.assertLess(mean_neg1, 2) + + # Run this test as well, just because we can. + self.run_test_sample_consistent_mean_covariance( + sess.run, vdm, rtol=0.02, cov_rtol=0.08) + def testMeanCovarianceNoBatchUncenteredNonStandardBase(self): with self.test_session() as sess: dims = 3 vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], - mix_scale=[10.], + temperature=[0.1], distribution=normal_lib.Normal(-1., 1.5), loc=[ np.float32([-2.]), @@ -177,7 +288,7 @@ class VectorDiffeomixtureTest( dims = 3 vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], - mix_scale=[10.], + temperature=[0.1], distribution=normal_lib.Normal(0., 1.), loc=[ np.float32([[-2.]]), @@ -205,7 +316,7 @@ class VectorDiffeomixtureTest( dims = 4 vdm = vdm_lib.VectorDiffeomixture( mix_loc=[0.], - mix_scale=[1.], + temperature=[0.1], distribution=normal_lib.Normal(0., 1.), loc=[ None, @@ -229,29 +340,6 @@ class VectorDiffeomixtureTest( self.run_test_sample_consistent_log_prob( sess.run, vdm, radius=4., center=2., rtol=0.005) - # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent, - # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't - # tested that the quadrature approach well-approximates the integral. - # - # To that end, consider adding these tests: - # - # Test1: In the limit of high mix_scale, this approximates a discrete mixture, - # and there are many discrete mixtures where we can explicitly compute - # mean/var, etc... So test1 would choose one of those discrete mixtures and - # show our mean/var/etc... is close to that. - # - # Test2: In the limit of low mix_scale, the a diffeomixture of Normal(-5, 1), - # Normal(5, 1) should (I believe...must check) should look almost like - # Uniform(-5, 5), and thus (i) .prob(x) should be about 1/10 for x in (-5, 5), - # and (ii) the first few moments should approximately match that of - # Uniform(-5, 5) - # - # Test3: If mix_loc is symmetric, then for any mix_scale, our - # quadrature-based diffeomixture of Normal(-1, 1), Normal(1, 1) should have - # mean zero, exactly. - - # TODO(jvdillon): Add more tests which verify broadcasting. - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py new file mode 100644 index 0000000000000000000000000000000000000000..852298bf334666db003353d5fc8e172ffb738668 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -0,0 +1,208 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Autoregressive distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import util as distribution_util + + +class Autoregressive(distribution_lib.Distribution): + """Autoregressive distributions. + + The Autoregressive distribution enables learning (often) richer multivariate + distributions by repeatedly applying a [diffeomorphic]( + https://en.wikipedia.org/wiki/Diffeomorphism) transformation (such as + implemented by `Bijector`s). Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + Practically speaking the autoregressive property means that there exists a + permutation of the event coordinates such that each coordinate is a + diffeomorphic function of only preceding coordinates. [2] + + #### Mathematical Details + + The probability function is, + + ```none + prob(x; fn, n) = fn(x).prob(x) + ``` + + And a sample is generated by, + + ```none + x = fn(...fn(fn(x0).sample()).sample()).sample() + ``` + + where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn` + constructs a `tf.distributions.Distribution`-like instance, and `x0` is a + fixed initializing `Tensor`. + + #### Examples + + ```python + tfd = tf.contrib.distributions + + def normal_fn(self, event_size): + n = event_size * (event_size + 1) / 2 + p = tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n)) + affine = tfd.bijectors.Affine( + scale_tril=tfd.fill_triangular(0.25 * p)) + def _fn(samples): + scale = math_ops.exp(affine.forward(samples)).eval() + return independent_lib.Independent( + normal_lib.Normal(loc=0., scale=scale, validate_args=True), + reinterpreted_batch_ndims=1) + return _fn + + batch_and_event_shape = [3, 2, 4] + sample0 = array_ops.zeros(batch_and_event_shape) + ar = autoregressive_lib.Autoregressive( + self._normal_fn(batch_and_event_shape[-1]), sample0) + x = ar.sample([6, 5]) + # ==> x.shape = [6, 5, 3, 2, 4] + prob_x = ar.prob(x) + # ==> x.shape = [6, 5, 3, 2] + + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "Conditional Image Generation with PixelCNN Decoders." + Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex + Graves, Koray Kavukcuoglu. Arxiv, 2016. + https://arxiv.org/abs/1606.05328 + """ + + def __init__(self, + distribution_fn, + sample0=None, + num_steps=None, + validate_args=False, + allow_nan_stats=True, + name="Autoregressive"): + """Construct an `Autoregressive` distribution. + + Args: + distribution_fn: Python `callable` which constructs a + `tf.distributions.Distribution`-like instance from a `Tensor` (e.g., + `sample0`). The function must respect the "autoregressive property", + i.e., there exists a permutation of event such that each coordinate is a + diffeomorphic function of on preceding coordinates. + sample0: Initial input to `distribution_fn`; used to + build the distribution in `__init__` which in turn specifies this + distribution's properties, e.g., `event_shape`, `batch_shape`, `dtype`. + If unspecified, then `distribution_fn` should be default constructable. + num_steps: Number of times `distribution_fn` is composed from samples, + e.g., `num_steps=2` implies + `distribution_fn(distribution_fn(sample0).sample(n)).sample()`. + validate_args: Python `bool`. Whether to validate input with asserts. + If `validate_args` is `False`, and the inputs are invalid, + correct behavior is not guaranteed. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + Default value: "Autoregressive". + + Raises: + ValueError: if `num_steps` and + `distribution_fn(sample0).event_shape.num_elements()` are both `None`. + ValueError: if `num_steps < 1`. + """ + parameters = locals() + with ops.name_scope(name): + self._distribution_fn = distribution_fn + self._sample0 = sample0 + self._distribution0 = (distribution_fn() if sample0 is None + else distribution_fn(sample0)) + if num_steps is None: + num_steps = self._distribution0.event_shape.num_elements() + if num_steps is None: + raise ValueError("distribution_fn must generate a distribution " + "with fully known `event_shape`.") + if num_steps < 1: + raise ValueError("num_steps ({}) must be at least 1.".format(num_steps)) + self._num_steps = num_steps + super(Autoregressive, self).__init__( + dtype=self._distribution0.dtype, + reparameterization_type=self._distribution0.reparameterization_type, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=self._distribution0._graph_parents, # pylint: disable=protected-access + name=name) + + @property + def distribution_fn(self): + return self._distribution_fn + + @property + def sample0(self): + return self._sample0 + + @property + def num_steps(self): + return self._num_steps + + @property + def distribution0(self): + return self._distribution0 + + def _batch_shape(self): + return self.distribution0.batch_shape + + def _batch_shape_tensor(self): + return self.distribution0.batch_shape_tensor() + + def _event_shape(self): + return self.distribution0.event_shape + + def _event_shape_tensor(self): + return self.distribution0.event_shape_tensor() + + def _sample_n(self, n, seed=None): + if seed is None: + seed = distribution_util.gen_new_seed( + seed=np.random.randint(2**32 - 1), + salt="autoregressive") + samples = self.distribution0.sample(n, seed=seed) + for _ in range(self._num_steps): + samples = self.distribution_fn(samples).sample(seed=seed) + return samples + + def _log_prob(self, value): + return self.distribution_fn(value).log_prob(value) + + def _prob(self, value): + return self.distribution_fn(value).prob(value) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index bc0ec7f195af009c87020ce8c4ea18f2e713759a..93923c3f083c7f5136b55e9021cbd6323684b976 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -29,6 +29,7 @@ @@MaskedAutoregressiveFlow @@Permute @@PowerTransform +@@RealNVP @@Reshape @@Sigmoid @@SigmoidCentered @@ -39,6 +40,7 @@ @@masked_autoregressive_default_template @@masked_dense +@@real_nvp_default_template """ from __future__ import absolute_import @@ -60,6 +62,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * +from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py new file mode 100644 index 0000000000000000000000000000000000000000..2840f52e742eac5e9e37a576bf7f6d6f05a07a35 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -0,0 +1,282 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Real NVP bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "RealNVP", + "real_nvp_default_template" +] + + +class RealNVP(bijector_lib.Bijector): + """RealNVP "affine coupling layer" for vector-valued events. + + Real NVP models a normalizing flow on a `D`-dimensional distribution via a + single `D-d`-dimensional conditional distribution [1]: + + `y[d:D] = y[d:D] * math_ops.exp(log_scale_fn(y[d:D])) + shift_fn(y[d:D])` + `y[0:d] = x[0:d]` + + The last `D-d` units are scaled and shifted based on the first `d` units only, + while the first `d` units are 'masked' and left unchanged. Real NVP's + `shift_and_log_scale_fn` computes vector-valued quantities. For + scale-and-shift transforms that do not depend on any masked units, i.e. + `d=0`, use the `tfb.Affine` bijector with learned parameters instead. + + Masking is currently only supported for base distributions with + `event_ndims=1`. For more sophisticated masking schemes like checkerboard or + channel-wise masking [2], use the `tfb.Permute` bijector to re-order desired + masked units into the first `d` units. For base distributions with + `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape. + + Recall that the MAF bijector [2] implements a normalizing flow via an + autoregressive transformation. MAF and IAF have opposite computational + tradeoffs - MAF can train all units in parallel but must sample units + sequentially, while IAF must train units sequentially but can sample in + parallel. In contrast, Real NVP can compute both forward and inverse + computations in parallel. However, the lack of an autoregressive + transformations makes it less expressive on a per-bijector basis. + + A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or + "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable + with the arguments to `forward` and `inverse`, i.e., such that the + calculations in `forward`, `inverse` [below] are possible. For convenience, + `real_nvp_default_nvp` is offered as a possible `shift_and_log_scale_fn` + function. + + NICE [3] is a special case of the Real NVP bijector which discards the scale + transformation, resulting in a constant-time inverse-log-determinant-Jacobian. + To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should + return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in + the `RealNVP` constructor. Calling `real_nvp_default_template` with + `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`. + + Caching: the scalar input depth `D` of the base distribution is not known at + construction time. The first call to any of `forward(x)`, `inverse(x)`, + `inverse_log_det_jacobian(x)`, or `forward_log_det_jacobian(x)` memoizes + `D`, which is re-used in subsequent calls. This shape must be known prior to + graph execution (which is the case if using tf.layers). + + #### Example Use + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + nvp = tfd.TransformedDistribution( + distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.])), + bijector=tfb.RealNVP( + num_masked=2, + shift_and_log_scale_fn=tfb.real_nvp_default_template( + hidden_layers=[512, 512]))) + + x = nvp.sample() + nvp.log_prob(x) + nvp.log_prob(0.) + ``` + + For more examples, see [4]. + + [1]: "Density Estimation using Real NVP." + Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017. + https://arxiv.org/abs/1605.08803 + + [2]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [3]: "NICE: Non-linear Independent Components Estimation." + Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015. + https://arxiv.org/abs/1410.8516 + + [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows." + Eric Jang. Blog post. January 2018. + http://blog.evjang.com/2018/01/nf2.html + """ + + def __init__(self, + num_masked, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + name=None): + """Creates the Real NVP or NICE bijector. + + Args: + num_masked: Python `int` indicating that the first `d` units of the event + should be masked. Must be in the closed interval `[1, D-1]`, where `D` + is the event size of the base distribution. + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + ValueError: If num_masked < 1. + """ + name = name or "real_nvp" + if num_masked <= 0: + raise ValueError("num_masked must be a positive integer.") + self._num_masked = num_masked + # At construction time, we don't know input_depth. + self._input_depth = None + self._shift_and_log_scale_fn = shift_and_log_scale_fn + super(RealNVP, self).__init__( + event_ndims=1, + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _cache_input_depth(self, x): + if self._input_depth is None: + self._input_depth = x.shape.with_rank_at_least(1)[-1].value + if self._input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + if self._num_masked >= self._input_depth: + raise ValueError( + "Number of masked units must be smaller than the event size.") + + def _forward(self, x): + self._cache_input_depth(x) + # Performs scale and shift. + x0, x1 = x[:, :self._num_masked], x[:, self._num_masked:] + shift, log_scale = self._shift_and_log_scale_fn( + x0, self._input_depth - self._num_masked) + y1 = x1 + if log_scale is not None: + y1 *= math_ops.exp(log_scale) + if shift is not None: + y1 += shift + y = array_ops.concat([x0, y1], axis=-1) + return y + + def _inverse(self, y): + self._cache_input_depth(y) + # Performs un-shift and un-scale. + y0, y1 = y[:, :self._num_masked], y[:, self._num_masked:] + shift, log_scale = self._shift_and_log_scale_fn( + y0, self._input_depth - self._num_masked) + x1 = y1 + if shift is not None: + x1 -= shift + if log_scale is not None: + x1 *= math_ops.exp(-log_scale) + x = array_ops.concat([y0, x1], axis=-1) + return x + + def _inverse_log_det_jacobian(self, y): + self._cache_input_depth(y) + y0 = y[:, :self._num_masked] + _, log_scale = self._shift_and_log_scale_fn( + y0, self._input_depth - self._num_masked) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + def _forward_log_det_jacobian(self, x): + self._cache_input_depth(x) + x0 = x[:, :self._num_masked] + _, log_scale = self._shift_and_log_scale_fn( + x0, self._input_depth - self._num_masked) + if log_scale is None: + return constant_op.constant(0., dtype=x.dtype, name="ildj") + return math_ops.reduce_sum(log_scale, axis=-1) + + +def real_nvp_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + name=None, + *args, + **kwargs): + """Build a scale-and-shift function using a multi-layer neural network. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the `d`-dimensional input x[0:d] and returns the `D-d` + dimensional outputs `loc` ("mu") and `log_scale` ("alpha"). + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed (i.e. NICE bijector). Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + name: A name for ops managed by this function. Default: + "real_nvp_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "real_nvp_default_template"): + def _fn(x, output_units): + """Fully connected MLP parameterized via `real_nvp_template`.""" + for units in hidden_layers: + x = layers.dense( + inputs=x, + units=units, + activation=activation, + *args, + **kwargs) + x = layers.dense( + inputs=x, + units=(1 if shift_only else 2) * output_units, + activation=None, + *args, + **kwargs) + if shift_only: + return x, None + shift, log_scale = array_ops.split(x, 2, axis=-1) + return shift, log_scale + return template_ops.make_template( + "real_nvp_default_template", _fn) diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 7ce8a83fd91e2dfaa0ccef633f803b3ae595e646..0c747f8e68529484ae6f695b8500cde74857bb11 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -50,20 +50,25 @@ __all__ = [ def quadrature_scheme_softmaxnormal_gauss_hermite( - loc, scale, quadrature_size, + normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): """Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex. + A `SoftmaxNormal` random variable `Y` may be generated via + + ``` + Y = SoftmaxCentered(X), + X = Normal(normal_loc, normal_scale) + ``` + Note: for a given `quadrature_size`, this method is generally less accurate than `quadrature_scheme_softmaxnormal_quantiles`. Args: - loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. - Represents the `location` parameter of the SoftmaxNormal used for - selecting one of the `K` affine transformations. - scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. - Represents the `scale` parameter of the SoftmaxNormal used for - selecting one of the `K` affine transformations. + normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + The location parameter of the Normal used to construct the SoftmaxNormal. + normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`. + The scale parameter of the Normal used to construct the SoftmaxNormal. quadrature_size: Python `int` scalar representing the number of quadrature points. validate_args: Python `bool`, default `False`. When `True` distribution @@ -80,24 +85,25 @@ def quadrature_scheme_softmaxnormal_gauss_hermite( associated with each grid point. """ with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite", - [loc, scale]): - loc = ops.convert_to_tensor(loc, name="loc") - dt = loc.dtype.base_dtype - scale = ops.convert_to_tensor(scale, dtype=dt, name="scale") + [normal_loc, normal_scale]): + normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc") + dt = normal_loc.dtype.base_dtype + normal_scale = ops.convert_to_tensor( + normal_scale, dtype=dt, name="normal_scale") - loc = maybe_check_quadrature_param(loc, "loc", validate_args) - scale = maybe_check_quadrature_param(scale, "scale", validate_args) + normal_scale = maybe_check_quadrature_param( + normal_scale, "normal_scale", validate_args) grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size) - grid = grid.astype(loc.dtype.as_numpy_dtype) - probs = probs.astype(loc.dtype.as_numpy_dtype) + grid = grid.astype(dt.dtype.as_numpy_dtype) + probs = probs.astype(dt.dtype.as_numpy_dtype) probs /= np.linalg.norm(probs, ord=1, keepdims=True) - probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype) + probs = ops.convert_to_tensor(probs, name="probs", dtype=dt) grid = softmax( -distribution_util.pad( - (loc[..., array_ops.newaxis] + - np.sqrt(2.) * scale[..., array_ops.newaxis] * grid), + (normal_loc[..., array_ops.newaxis] + + np.sqrt(2.) * normal_scale[..., array_ops.newaxis] * grid), axis=-2, front=True), axis=-2) # shape: [B, components, deg] @@ -106,18 +112,23 @@ def quadrature_scheme_softmaxnormal_gauss_hermite( def quadrature_scheme_softmaxnormal_quantiles( - loc, scale, quadrature_size, + normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex. + A `SoftmaxNormal` random variable `Y` may be generated via + + ``` + Y = SoftmaxCentered(X), + X = Normal(normal_loc, normal_scale) + ``` + Args: - loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. - Represents the `location` parameter of the SoftmaxNormal used for - selecting one of the `K` affine transformations. - scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. - Represents the `scale` parameter of the SoftmaxNormal used for - selecting one of the `K` affine transformations. - quadrature_size: Python scalar `int` representing the number of quadrature + normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + The location parameter of the Normal used to construct the SoftmaxNormal. + normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`. + The scale parameter of the Normal used to construct the SoftmaxNormal. + quadrature_size: Python `int` scalar representing the number of quadrature points. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime @@ -132,15 +143,17 @@ def quadrature_scheme_softmaxnormal_quantiles( probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the associated with each grid point. """ - with ops.name_scope(name, "softmax_normal_grid_and_probs", [loc, scale]): - loc = ops.convert_to_tensor(loc, name="loc") - dt = loc.dtype.base_dtype - scale = ops.convert_to_tensor(scale, dtype=dt, name="scale") + with ops.name_scope(name, "softmax_normal_grid_and_probs", + [normal_loc, normal_scale]): + normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc") + dt = normal_loc.dtype.base_dtype + normal_scale = ops.convert_to_tensor( + normal_scale, dtype=dt, name="normal_scale") - loc = maybe_check_quadrature_param(loc, "loc", validate_args) - scale = maybe_check_quadrature_param(scale, "scale", validate_args) + normal_scale = maybe_check_quadrature_param( + normal_scale, "normal_scale", validate_args) - dist = normal_lib.Normal(loc=loc, scale=scale) + dist = normal_lib.Normal(loc=normal_loc, scale=normal_scale) def _get_batch_ndims(): """Helper to get dist.batch_shape.ndims, statically if possible.""" @@ -195,114 +208,51 @@ def quadrature_scheme_softmaxnormal_quantiles( class VectorDiffeomixture(distribution_lib.Distribution): """VectorDiffeomixture distribution. - The VectorDiffeomixture is an approximation to a [compound distribution]( - https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., + A vector diffeomixture (VDM) is a distribution parameterized by a convex + combination of `K` component `loc` vectors, `loc[k], k = 0,...,K-1`, and `K` + `scale` matrices `scale[k], k = 0,..., K-1`. It approximates the following + [compound distribution] + (https://en.wikipedia.org/wiki/Compound_probability_distribution) ```none - p(x) = int_{X} q(x | v) p(v) dv - = lim_{Q->infty} sum{ prob[i] q(x | loc=sum_k^K lambda[k;i] loc[k], - scale=sum_k^K lambda[k;i] scale[k]) - : i=0, ..., Q-1 } + p(x) = int p(x | z) p(z) dz, + where z is in the K-simplex, and + p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k]) ``` - where `q(x | v)` is a vector version of the `distribution` argument and `p(v)` - is a SoftmaxNormal parameterized by `mix_loc` and `mix_scale`. The - vector-ization of `distribution` entails an affine transformation of iid - samples from `distribution`. The `prob` term is from quadrature and - `lambda[k] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[k])` where the - `grid` points correspond to the `prob`s. - - In the non-approximation case, a draw from the mixture distribution (the - "prior") represents the convex weights for different affine transformations. - I.e., draw a mixing vector `v` (from the `K-1`-simplex) and let the final - sample be: `y = (sum_k^K v[k] scale[k]) @ x + (sum_k^K v[k] loc[k])` where `@` - denotes matrix multiplication. However, the non-approximate distribution does - not have an analytical probability density function (pdf). Therefore the - `VectorDiffeomixture` class implements an approximation based on - [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in - Note: although the `VectorDiffeomixture` is approximately the - `SoftmaxNormal-Distribution` compound distribution, it is itself a valid - distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which - are all mutually consistent. - - #### Intended Use - - This distribution is noteworthy because it implements a mixture of - `Vector`-ized distributions yet has samples differentiable in the - distribution's parameters (aka "reparameterized"). It has an analytical - density function with `O(dKQ)` complexity. `d` is the vector dimensionality, - `K` is the number of components, and `Q` is the number of quadrature points. - These properties make it well-suited for Bayesian Variational Inference, i.e., - as a surrogate family for the posterior. - - For large values of `mix_scale`, the `VectorDistribution` behaves increasingly - like a discrete mixture. (In most cases this limit is only achievable by also - increasing the quadrature polynomial degree, `Q`.) - - The term `Vector` is consistent with similar named Tensorflow `Distribution`s. - For more details, see the "About `Vector` distributions in Tensorflow." - section. - - The term `Diffeomixture` is a portmanteau of - [diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism) and [compound - mixture](https://en.wikipedia.org/wiki/Compound_probability_distribution). For - more details, see the "About `Diffeomixture`s and reparametrization.`" - section. - - #### Mathematical Details - - The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior") - [compound distribution]( - https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can - redefine the distribution to be a parameter-less convex combination of `K` - different affine combinations of a `d` iid samples from `distribution`. - - That is, defined over `R**d` this distribution is parameterized by a - (batch of) length-`K` `mix_loc` and `mix_scale` vectors, a length-`K` list of - (a batch of) length-`d` `loc` vectors, and a length-`K` list of `scale` - `LinearOperator`s each operating on a (batch of) length-`d` vector space. - Finally, a `distribution` parameter specifies the underlying base distribution - which is "lifted" to become multivariate ("lifting" is the same concept as in - `TransformedDistribution`). - - The probability density function (pdf) is, + The integral `int p(x | z) p(z) dz` is approximated with a quadrature scheme + adapted to the mixture density `p(z)`. The `N` quadrature points `z_{N, n}` + and weights `w_{N, n}` (which are non-negative and sum to 1) are chosen + such that - ```none - pdf(y; mix_loc, mix_scale, loc, scale, phi) - = sum{ prob[i] phi(f_inverse(x; i)) / abs(det(interp_scale[i])) - : i=0, ..., Q-1 } - ``` + ```q_N(x) := sum_{n=1}^N w_{n, N} p(x | z_{N, n}) --> p(x)``` - where, `phi` is the base distribution pdf, and, + as `N --> infinity`. - ```none - f_inverse(x; i) = inv(interp_scale[i]) @ (x - interp_loc[i]) - interp_loc[i] = sum{ lambda[k; i] loc[k] : k=0, ..., K-1 } - interp_scale[i] = sum{ lambda[k; i] scale[k] : k=0, ..., K-1 } - ``` + Since `q_N(x)` is in fact a mixture (of `N` points), we may sample from + `q_N` exactly. It is important to note that the VDM is *defined* as `q_N` + above, and *not* `p(x)`. Therefore, sampling and pdf may be implemented as + exact (up to floating point error) methods. - and, + A common choice for the conditional `p(x | z)` is a multivariate Normal. - ```none - grid, weight = np.polynomial.hermite.hermgauss(quadrature_size) - prob[k] = weight[k] / sqrt(pi) - lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i]) + The implemented marginal `p(z)` is the `SoftmaxNormal`, which is a + `K-1` dimensional Normal transformed by a `SoftmaxCentered` bijector, making + it a density on the `K`-simplex. That is, + + ``` + Z = SoftmaxCentered(X), + X = Normal(mix_loc / temperature, 1 / temperature) ``` - The distribution corresponding to `phi` must be a scalar-batch, scalar-event - distribution. Typically it is reparameterized. If not, it must be a function - of non-trainable parameters. + The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of + the quantiles of `p(z)` (generalized quantiles if `K > 2`). - WARNING: If you backprop through a VectorDiffeomixture sample and the "base" - distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable - variables, then the gradient is not guaranteed correct! + See [1] for more details. + + [1]. "Quadrature Compound: An approximating family of distributions" + Joshua Dillon, Ian Langmore, arXiv preprints + https://arxiv.org/abs/1801.03080 #### About `Vector` distributions in TensorFlow. @@ -310,12 +260,11 @@ class VectorDiffeomixture(distribution_lib.Distribution): particularly useful in [variational Bayesian methods](https://en.wikipedia.org/wiki/Variational_Bayesian_methods). - Conditioned on a draw from the SoftmaxNormal, `Y|v` is a vector whose + Conditioned on a draw from the SoftmaxNormal, `X|z` is a vector whose components are linear combinations of affine transformations, thus is itself - an affine transformation. Therefore `Y|v` lives in the vector space generated - by vectors of affine-transformed distributions. + an affine transformation. - Note: The marginals `Y_1|v, ..., Y_d|v` are *not* generally identical to some + Note: The marginals `X_1|v, ..., X_d|v` are *not* generally identical to some parameterization of `distribution`. This is due to the fact that the sum of draws from `distribution` are not generally itself the same `distribution`. @@ -331,12 +280,16 @@ class VectorDiffeomixture(distribution_lib.Distribution): optimize Monte-Carlo objectives. Such objectives are a finite-sample approximation of an expectation and arise throughout scientific computing. + WARNING: If you backprop through a VectorDiffeomixture sample and the "base" + distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable + variables, then the gradient is not guaranteed correct! + #### Examples ```python tfd = tf.contrib.distributions - # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and + # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.], # another with mix_loc=[1]. In both cases, `K=2` and the affine # transformations involve: # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity @@ -344,7 +297,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): dims = 5 vdm = tfd.VectorDiffeomixture( mix_loc=[[0.], [1]], - mix_scale=[1.], + temperature=[1.], distribution=tfd.Normal(loc=0., scale=1.), loc=[ None, # Equivalent to `np.zeros(dims, dtype=np.float32)`. @@ -364,7 +317,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): def __init__(self, mix_loc, - mix_scale, + temperature, distribution, loc=None, scale=None, @@ -373,15 +326,24 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=False, allow_nan_stats=True, name="VectorDiffeomixture"): - """Constructs the VectorDiffeomixture on `R**d`. + """Constructs the VectorDiffeomixture on `R^d`. + + The vector diffeomixture (VDM) approximates the compound distribution + + ```none + p(x) = int p(x | z) p(z) dz, + where z is in the K-simplex, and + p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k]) + ``` Args: - mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. Represents - the `location` parameter of the SoftmaxNormal used for selecting one of - the `K` affine transformations. - mix_scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. - Represents the `scale` parameter of the SoftmaxNormal used for selecting - one of the `K` affine transformations. + mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. + In terms of samples, larger `mix_loc[..., k]` ==> + `Z` is more likely to put more weight on its `kth` component. + temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`. + In terms of samples, smaller `temperature` means one component is more + likely to dominate. I.e., smaller `temperature` makes the VDM look more + like a standard mixture of `K` components. distribution: `tf.Distribution`-like instance. Distribution from which `d` iid samples are used as input to the selected affine transformation. Must be a scalar-batch, scalar-event distribution. Typically @@ -401,8 +363,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices quadrature_size: Python `int` scalar representing number of - quadrature points. - quadrature_fn: Python callable taking `mix_loc`, `mix_scale`, + quadrature points. Larger `quadrature_size` means `q_N(x)` better + approximates `p(x)`. + quadrature_fn: Python callable taking `normal_loc`, `normal_scale`, `quadrature_size`, `validate_args` and returning `tuple(grid, probs)` representing the SoftmaxNormal grid and corresponding normalized weight. normalized) weight. @@ -430,7 +393,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ValueError: if `not distribution.is_scalar_event`. """ parameters = locals() - with ops.name_scope(name, values=[mix_loc, mix_scale]): + with ops.name_scope(name, values=[mix_loc, temperature]): if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " "LinearOperators, one for each component with " @@ -473,8 +436,15 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) + mix_loc = ops.convert_to_tensor( + mix_loc, dtype=dtype, name="mix_loc") + temperature = ops.convert_to_tensor( + temperature, dtype=dtype, name="temperature") self._grid, probs = tuple(quadrature_fn( - mix_loc, mix_scale, quadrature_size, validate_args)) + mix_loc / temperature, + 1. / temperature, + quadrature_size, + validate_args)) # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to @@ -618,7 +588,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): weight = array_ops.gather( array_ops.reshape(self.grid, shape=[-1]), ids + offset) - weight = weight[..., array_ops.newaxis] + # At this point, weight flattened all batch dims into one. + # We also need to append a singleton to broadcast with event dims. + if self.batch_shape.is_fully_defined(): + new_shape = [-1] + self.batch_shape.as_list() + [1] + else: + new_shape = array_ops.concat( + ([-1], self.batch_shape_tensor(), [1]), axis=0) + weight = array_ops.reshape(weight, shape=new_shape) if len(x) != 2: # We actually should have already triggered this exception. However as a @@ -686,7 +663,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): # To compute E[Cov(Z|V)], we'll add matrices within three categories: # scaled-identity, diagonal, and full. Then we'll combine these at the end. - scaled_identity = None + scale_identity_multiplier = None diag = None full = None @@ -694,10 +671,12 @@ class VectorDiffeomixture(distribution_lib.Distribution): s = aff.scale # Just in case aff.scale has side-effects, we'll call once. if (s is None or isinstance(s, linop_identity_lib.LinearOperatorIdentity)): - scaled_identity = add(scaled_identity, p[..., k, array_ops.newaxis]) + scale_identity_multiplier = add(scale_identity_multiplier, + p[..., k, array_ops.newaxis]) elif isinstance(s, linop_identity_lib.LinearOperatorScaledIdentity): - scaled_identity = add(scaled_identity, (p[..., k, array_ops.newaxis] * - math_ops.square(s.multiplier))) + scale_identity_multiplier = add( + scale_identity_multiplier, + (p[..., k, array_ops.newaxis] * math_ops.square(s.multiplier))) elif isinstance(s, linop_diag_lib.LinearOperatorDiag): diag = add(diag, (p[..., k, array_ops.newaxis] * math_ops.square(s.diag_part()))) @@ -709,12 +688,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): full = add(full, x) # We must now account for the fact that the base distribution might have a - # non-unity variance. Recall that `Cov(SX+m) = S.T Cov(X) S = S.T S Var(X)`. + # non-unity variance. Recall that, since X ~ iid Law(X_0), + # `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`. # We can scale by `Var(X)` (vs `Cov(X)`) since X corresponds to `d` iid # samples from a scalar-event distribution. v = self.distribution.variance() - if scaled_identity is not None: - scaled_identity *= v + if scale_identity_multiplier is not None: + scale_identity_multiplier *= v if diag is not None: diag *= v[..., array_ops.newaxis] if full is not None: @@ -723,10 +703,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): if diag_only: # Apparently we don't need the full matrix, just the diagonal. r = add(diag, full) - if r is None and scaled_identity is not None: + if r is None and scale_identity_multiplier is not None: ones = array_ops.ones(self.event_shape_tensor(), dtype=self.dtype) - return scaled_identity * ones - return add(r, scaled_identity) + return scale_identity_multiplier[..., array_ops.newaxis] * ones + return add(r, scale_identity_multiplier) # `None` indicates we don't know if the result is positive-definite. is_positive_definite = (True if all(aff.scale.is_positive_definite @@ -742,10 +722,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): to_add.append(linop_full_lib.LinearOperatorFullMatrix( matrix=full, is_positive_definite=is_positive_definite)) - if scaled_identity is not None: + if scale_identity_multiplier is not None: to_add.append(linop_identity_lib.LinearOperatorScaledIdentity( num_rows=self.event_shape_tensor()[0], - multiplier=scaled_identity, + multiplier=scale_identity_multiplier, is_positive_definite=is_positive_definite)) return (linop_add_lib.add_operators(to_add)[0].to_dense() diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py index f820990bbe5fe6c9b4cdf890680aaad0847010c0..ff419614f580d3bace9d99648478cc2204d7801d 100644 --- a/tensorflow/contrib/eager/python/checkpointable_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_test.py @@ -70,42 +70,36 @@ class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): checkpointable.Checkpointable.__init__(self) adam.AdamOptimizer.__init__(self, *args, **kwargs) - # NOTE: Copied from AdamOptimizer with modifications to use add_variable + # NOTE: Copied from Optimizer with modifications to use add_variable # for non-slot variables. These contortions are necessary to maintain # checkpoint compatibility with variable.name based saving. - def _create_slots(self, var_list): - # Create the beta1 and beta2 accumulators on the same device as the first - # variable. Sort the var_list to make sure this device is consistent across - # workers (these need to go on the same PS, otherwise some updates are - # silently ignored). - first_var = min(var_list, key=lambda x: x.name) - - create_new = self._beta1_power is None - if not create_new and context.in_graph_mode(): - create_new = (self._beta1_power.graph is not first_var.graph) - - if create_new: - with ops.colocate_with(first_var): + # TODO(allenl): Make this cleaner. + def _create_non_slot_variable(self, initial_value, name, colocate_with): + """Add an extra variable, not associated with a slot.""" + if context.in_graph_mode(): + graph = colocate_with.graph + else: + graph = None + key = (name, graph) + v = self._non_slot_dict.get(key, None) + if v is None: + with ops.colocate_with(colocate_with): def _variable_getter(name, shape, dtype, initializer): del shape, dtype # not used, but there for compatibility return variable_scope.variable( name=name, initial_value=initializer, trainable=False) - self._beta1_power = self.add_variable( - name="beta1_power", - shape=[], - initializer=self._beta1, + initial_value = ops.convert_to_tensor(initial_value) + v = self.add_variable( + name=name, + shape=initial_value.get_shape(), + initializer=initial_value, getter=_variable_getter) - self._beta2_power = self.add_variable( - name="beta2_power", - shape=[], - initializer=self._beta2, - getter=_variable_getter) - # Create slots for the first and second moments. - for v in var_list: - self._zeros_slot(v, "m", self._name) - self._zeros_slot(v, "v", self._name) + + self._non_slot_dict[key] = v + + return v # TODO(allenl): Override slot variable creation (_get_or_make_slot, # _get_or_make_slot_with_initializer, _zeros_slot) to allow deferred diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index a7f50c13bb992fd47669fb9956dde6b271e16ffd..544a3eafc08f892f6e3315f0656c97b9877cfa0e 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -141,7 +141,12 @@ class Iterator(object): # TODO(ashankar): Consider removing this ops.device() contextmanager # and instead mimic ops placement in graphs: Operations on resource # handles execute on the same device as where the resource is placed. - ret = gen_dataset_ops.iterator_get_next( + # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` + # because in eager mode this code will run synchronously on the calling + # thread. Therefore we do not need to make a defensive context switch + # to a background thread, and can achieve a small constant performance + # boost by invoking the iterator synchronously. + ret = gen_dataset_ops.iterator_get_next_sync( self._resource, output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 3faaeef5903615ea122800a6690117dde682e830..68e7b5421fec7f73f10e381ca45f9d900de299d7 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -178,7 +178,7 @@ class Evaluator(object): call_op: An op that updates evaluation state on a mini-batch of examples. Must generate an tf.errors.OutOfRangeError when done. results_op: A dictionary of tensors that compute the final evaluation - results from the evaulation state. + results from the evaluation state. sess: The Session to run the evaluation in. Defaults to the default Session. diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index bab7ad0c701b2110fda9a8d27792fd361a5fc1c0..f86331af6f7928f0f86c888e22706c6e0a5978b2 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -23,3 +23,13 @@ cuda_py_test( "//tensorflow:tensorflow_py", ], ) + +cuda_py_test( + name = "linear_regression_graph_test", + size = "small", + srcs = ["linear_regression_graph_test.py"], + additional_deps = [ + ":linear_regression", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index f4b7d67f940f5d752e1d22d643b763e2d97e987e..6ce4de6ee0bf50400eff339ac04e132252a2b53e 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -63,6 +63,10 @@ class LinearModel(tfe.Network): return self._hidden_layer(xs) +def mean_square_loss(model, xs, ys): + return tf.reduce_mean(tf.square(model(xs) - ys)) + + def fit(model, dataset, optimizer, verbose=False, logdir=None): """Fit the linear-regression model. @@ -76,10 +80,8 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): """ # The loss function to optimize. - def mean_square_loss(xs, ys): - return tf.reduce_mean(tf.square(model(xs) - ys)) - - loss_and_grads = tfe.implicit_value_and_gradients(mean_square_loss) + mse = lambda xs, ys: mean_square_loss(model, xs, ys) + loss_and_grads = tfe.implicit_value_and_gradients(mse) tf.train.get_or_create_global_step() if logdir: @@ -103,14 +105,20 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): def synthetic_dataset(w, b, noise_level, batch_size, num_batches): """tf.data.Dataset that yields synthetic data for linear regression.""" + return synthetic_dataset_helper(w, b, + tf.shape(w)[0], noise_level, batch_size, + num_batches) + +def synthetic_dataset_helper(w, b, num_features, noise_level, batch_size, + num_batches): # w is a matrix with shape [N, M] # b is a vector with shape [M] # So: # - Generate x's as vectors with shape [batch_size N] # - y = tf.matmul(x, W) + b + noise def batch(_): - x = tf.random_normal([batch_size, tf.shape(w)[0]]) + x = tf.random_normal([batch_size, num_features]) y = tf.matmul(x, w) + b + noise_level * tf.random_normal([]) return x, y diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..557ad42752144243ae3da61b955b31398cba846e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Graph benchmark for linear regression, to contrast with eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.linear_regression import linear_regression + + +class GraphLinearRegressionBenchmark(tf.test.Benchmark): + + def benchmarkGraphLinearRegression(self): + num_epochs = 10 + num_batches = 200 + batch_size = 64 + dataset = linear_regression.synthetic_dataset_helper( + w=tf.random_uniform([3, 1]), + b=tf.random_uniform([1]), + num_features=3, + noise_level=0.01, + batch_size=batch_size, + num_batches=num_batches) + iterator = dataset.make_initializable_iterator() + x, y = iterator.get_next() + + model = linear_regression.LinearModel() + + if tf.test.is_gpu_available(): + use_gpu = True + device = "/device:GPU:0" + else: + use_gpu = False + device = "/device:CPU:0" + + with tf.device(device): + loss = linear_regression.mean_square_loss(model, x, y) + optimization_step = tf.train.GradientDescentOptimizer( + learning_rate=0.1).minimize(loss) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + def train(num_epochs): + for _ in range(num_epochs): + sess.run(iterator.initializer) + try: + while True: + _, _ = sess.run([optimization_step, loss]) + except tf.errors.OutOfRangeError: + pass + + # Warmup: a single epoch. + train(1) + + start_time = time.time() + train(num_epochs) + wall_time = time.time() - start_time + + examples_per_sec = num_epochs * num_batches * batch_size / wall_time + self.report_benchmark( + name="graph_train_%s" % + ("gpu" if use_gpu else "cpu"), + iters=num_epochs * num_batches, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py index 39e7aabd7be04ba36a786a4c08d0df6c2ce916d0..e53234b51a7dccc11e548ac81a7ef070c628aa52 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -83,6 +83,7 @@ class LinearRegressionTest(tf.test.TestCase): class EagerLinearRegressionBenchmark(tf.test.Benchmark): def benchmarkEagerLinearRegression(self): + num_epochs = 10 num_batches = 200 batch_size = 64 dataset = linear_regression.synthetic_dataset( @@ -102,14 +103,15 @@ class EagerLinearRegressionBenchmark(tf.test.Benchmark): linear_regression.fit(model, burn_in_dataset, optimizer) start_time = time.time() - linear_regression.fit(model, dataset, optimizer) + for _ in range(num_epochs): + linear_regression.fit(model, dataset, optimizer) wall_time = time.time() - start_time - examples_per_sec = num_batches * batch_size / wall_time + examples_per_sec = num_epochs * num_batches * batch_size / wall_time self.report_benchmark( name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"), - iters=num_batches, + iters=num_epochs * num_batches, extras={"examples_per_sec": examples_per_sec}, wall_time=wall_time) diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index 82b3d3919cf0176961853d2bd85802e5dafa789e..2a7be95811f6fff06e2c489890703561ed879c42 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -23,7 +23,6 @@ from __future__ import division from __future__ import print_function import argparse -import functools import os import sys import time @@ -124,21 +123,18 @@ def train_one_epoch(model, optimizer, dataset, log_interval=None): tf.train.get_or_create_global_step() - def model_loss(labels, images): - prediction = model(images, training=True) - loss_value = loss(prediction, labels) - tf.contrib.summary.scalar('loss', loss_value) - tf.contrib.summary.scalar('accuracy', - compute_accuracy(prediction, labels)) - return loss_value - for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): with tf.contrib.summary.record_summaries_every_n_global_steps(10): - batch_model_loss = functools.partial(model_loss, labels, images) - optimizer.minimize( - batch_model_loss, global_step=tf.train.get_global_step()) + with tfe.GradientTape() as tape: + prediction = model(images, training=True) + loss_value = loss(prediction, labels) + tf.contrib.summary.scalar('loss', loss_value) + tf.contrib.summary.scalar('accuracy', + compute_accuracy(prediction, labels)) + grads = tape.gradient(loss_value, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) if log_interval and batch % log_interval == 0: - print('Batch #%d\tLoss: %.6f' % (batch, batch_model_loss())) + print('Batch #%d\tLoss: %.6f' % (batch, loss_value)) def test(model, dataset): diff --git a/tensorflow/contrib/eager/python/examples/resnet50/README.md b/tensorflow/contrib/eager/python/examples/resnet50/README.md index db023e6c976c8eda09ef0dee7eecb144678773c4..79e460052945718eac194653015d60d900998e2d 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/README.md +++ b/tensorflow/contrib/eager/python/examples/resnet50/README.md @@ -34,7 +34,7 @@ bazel run -c opt --config=cuda :resnet50_graph_test -- --benchmarks=. (Or remove the `--config=cuda` flag for running on CPU instead of GPU). -On October 31, 2017, the benchmarks demostrated comparable performance +On October 31, 2017, the benchmarks demonstrated comparable performance for eager and graph execution of this particular model when using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32. diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index b302a87e0e8a61d2456db1eba847f31bd70f552e..9982fdb07eefa665379e7be095f4f8017d92cf97 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -97,7 +97,7 @@ class _ConvBlock(tfe.Network): Args: kernel_size: the kernel size of middle conv layer at main path - filters: list of integers, the filterss of 3 conv layer at main path + filters: list of integers, the filters of 3 conv layer at main path stage: integer, current stage label, used for generating layer names block: 'a','b'..., current block label, used for generating layer names data_format: data_format for the input ('channels_first' or diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index e2ae665a74fcf297b3174006783a7b8fed19ff03..76e06269b6bbeb3386a6346244d294b1c5167b6e 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -52,14 +52,13 @@ def random_batch(batch_size): def train_one_step(model, images, labels, optimizer): - def model_loss(): + with tfe.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - return loss - - optimizer.minimize(model_loss) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) class ResNet50Test(tf.test.TestCase): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md index 743ebb68ee5bba5635899267cc4839828f7e4e2f..966177e91c212c1aa132fe3af6f7dc9a50fb984e 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md @@ -40,7 +40,7 @@ bazel run -c opt --config=cuda :rnn_ptb_graph_test -- --benchmarks=. (Or remove the `--config=cuda` flag for running on CPU instead of GPU). -On October 31, 2017, the benchmarks demostrated slightly better performance +On October 31, 2017, the benchmarks demonstrated slightly better performance (3-6%) for graph execution over eager execution for this particular model when using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32. diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 7b9637a9d58c87e93c7c0ea7173a6b88c885ee25..d34e9ea68b76373d4b5a9ee9e3852c60a7c81525 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -88,7 +88,7 @@ class Embedding(tf.layers.Layer): class PTBModel(tfe.Network): - """LSTM for word language modelling. + """LSTM for word language modeling. Model described in: (Zaremba, et. al.) Recurrent Neural Network Regularization @@ -340,7 +340,7 @@ if __name__ == "__main__": parser.add_argument( "--logdir", type=str, default="", help="Directory for checkpoint.") parser.add_argument( - "--epoch", type=int, default=20, help="Number of epoches.") + "--epoch", type=int, default=20, help="Number of epochs.") parser.add_argument("--batch-size", type=int, default=20, help="Batch size.") parser.add_argument( "--seq-len", type=int, default=35, help="Sequence length.") diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py index a6e046320f78541bef4e091e97f08fd51857af83..fcaae0a4f8c0bad916d74bd9b80fcfa55a63d84a 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -51,11 +51,11 @@ def get_non_parenthesis_words(items): """Get the non-parenthesis items from a SNLI parsed sentence. Args: - items: Data items from a parsed SNLI setence, with parentheses. E.g., + items: Data items from a parsed SNLI sentence, with parentheses. E.g., ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... Returns: - A list of non-parenthis word items, all converted to lower case. E.g., + A list of non-parentheses word items, all converted to lower case. E.g., ["man", "wearing", "pass", ... """ return [x.lower() for x in items if x not in PARENTHESES and x] @@ -201,7 +201,7 @@ def load_word_vectors(data_root, vocab): def calculate_bins(length2count, min_bin_size): - """Cacluate bin boundaries given a histogram of lengths and mininum bin size. + """Calculate bin boundaries given a histogram of lengths and minimum bin size. Args: length2count: A `dict` mapping length to sentence count. @@ -335,9 +335,9 @@ class SnliData(object): # The sorting above and the batching here makes sure that sentences of # similar max lengths are batched together, minimizing the inefficiency # due to uneven max lengths. The sentences are batched differently in - # each call to get_generator() due to the shuffling before sotring + # each call to get_generator() due to the shuffling before sorting # above. The pad_and_reverse_word_ids() and pad_transitions() functions - # take care of any remaning unevenness of the max sentence lengths. + # take care of any remaining unevenness of the max sentence lengths. end = min(begin + batch_size, len(labels)) # Transpose, because the SPINN model requires time-major, instead of # batch-major. diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 0095ffa0db99d46d25654d73504d0d7d41c18b6f..7eea93ce1f5aefe82d73b49f57b636692818ba16 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -292,7 +292,7 @@ def loss(weight, bias): error = prediction(training_inputs, weight, bias) - training_outputs return tf.reduce_mean(tf.square(error)) -# Function that returns the the derivative of loss with respect to +# Function that returns the derivative of loss with respect to # weight and bias grad = tfe.gradients_function(loss) diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 8e6b947e5cb28910bcb4877aa66150992a8d6445..81c77e41acf420fa84857ccb366aa2fbd6055f42 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -688,7 +688,7 @@ class NetworkTest(test.TestCase): net2(one) # Layer names typically are globally unique rather than being unique within # the scope of their first use. However, within a Network they must be named - # locally so that previous Layer consutrciton does not interfere with + # locally so that previous Layer construction does not interfere with # variable naming (e.g. add a Layer construction before the Network, # suddenly your previously saved checkpoint is incompatible). self.assertEqual("dense", net1.l1.name) diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 57b070ec6eeac00c77f199a846639d64c4957cd8..62421849c766a1124c726812428985c913c653a3 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -82,7 +82,7 @@ def restore_variables_on_create(save_path, map_func=None): map_func_wrapper = lambda self, x: x else: if not callable(map_func): - raise ValueError("map_func must be callaled.") + raise ValueError("map_func must be callable.") map_func_wrapper = lambda self, x: map_func(x) ckpt_var_cache = dict() diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index d6ca33e18923a5dd996431b0ff87c6ad3bccea92..fd0994490aac7b9a0ed628e0c3c624d0fefb1b81 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -220,7 +220,7 @@ def multi_label_head(n_classes, `batch_size`. The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many - applications, the shape is `[batch_size, label_n_classes]`. + applications, the shape is `[batch_size, n_classes]`. Labels can be: * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` @@ -392,8 +392,32 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access processed_labels=processed_labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): - """See `Head`.""" + self, features, mode, logits, labels=None, train_op_fn=None, + regularization_losses=None): + """Returns an `EstimatorSpec`. + + Args: + features: Input `dict` of `Tensor` or `SparseTensor` objects. + mode: Estimator's `ModeKeys`. + logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`. + For many applications, the shape is `[batch_size, n_classes]`. + labels: Labels with shape matching `logits`. Can be multi-hot `Tensor` + with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with + `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when + `mode` equals `TRAIN` or `EVAL`. + train_op_fn: Function that takes a scalar loss `Tensor` and returns + `train_op`. Required in TRAIN mode. + regularization_losses: A list of additional scalar losses to be added to + the training loss, such as regularization losses. These losses are + usually expressed as a batch average, so for best results users need to + set `loss_reduction=SUM_OVER_BATCH_SIZE` or + `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to + avoid scaling errors. + Returns: + `EstimatorSpec`. + Raises: + ValueError: If `train_op_fn` is `None` in TRAIN mode. + """ with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access @@ -422,18 +446,26 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access (training_loss, unreduced_loss, weights, processed_labels) = self.create_loss( features=features, mode=mode, logits=logits, labels=labels) + if regularization_losses: + regularization_loss = math_ops.add_n(regularization_losses) + regularized_training_loss = math_ops.add_n( + [training_loss, regularization_loss]) + else: + regularization_loss = None + regularized_training_loss = training_loss # Eval. if mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, - loss=training_loss, + loss=regularized_training_loss, eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, weights=weights, - unreduced_loss=unreduced_loss)) + unreduced_loss=unreduced_loss, + regularization_loss=regularization_loss)) # Train. if train_op_fn is None: @@ -447,25 +479,31 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access else: mean_loss = None with ops.name_scope(''): + keys = metric_keys.MetricKeys summary.scalar( - head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access - training_loss) + head_lib._summary_key(self._name, keys.LOSS), # pylint:disable=protected-access + regularized_training_loss) if mean_loss is not None: summary.scalar( - head_lib._summary_key( # pylint:disable=protected-access - self._name, metric_keys.MetricKeys.LOSS_MEAN), + head_lib._summary_key(self._name, keys.LOSS_MEAN), # pylint:disable=protected-access mean_loss) + if regularization_loss is not None: + summary.scalar( + head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access + regularization_loss) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, - loss=training_loss, - train_op=train_op_fn(training_loss)) + loss=regularized_training_loss, + train_op=train_op_fn(regularized_training_loss)) - def _eval_metric_ops(self, labels, probabilities, weights, unreduced_loss): + def _eval_metric_ops( + self, labels, probabilities, weights, unreduced_loss, + regularization_loss): """Returns a dict of metrics for eval_metric_ops.""" with ops.name_scope( None, 'metrics', - [labels, probabilities, weights, unreduced_loss]): + [labels, probabilities, weights, unreduced_loss, regularization_loss]): keys = metric_keys.MetricKeys metric_ops = { # Estimator already adds a metric for loss. @@ -482,6 +520,13 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, curve='PR', name=keys.AUC_PR), } + if regularization_loss is not None: + loss_regularization_key = head_lib._summary_key( # pylint:disable=protected-access + self._name, keys.LOSS_REGULARIZATION) + metric_ops[loss_regularization_key] = ( + metrics_lib.mean( + values=regularization_loss, + name=keys.LOSS_REGULARIZATION)) for threshold in self._thresholds: accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold metric_ops[head_lib._summary_key(self._name, accuracy_key)] = ( # pylint:disable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index e39e44541d2d30b1ecc9d4d41d0760decdc58168..1adbd6f0fe32df4a513a2683d03fcefca07e2a42 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -399,12 +399,13 @@ class MultiLabelHead(test.TestCase): def _test_eval( self, head, logits, labels, expected_loss, expected_metrics, - features=None): + features=None, regularization_losses=None): spec = head.create_estimator_spec( features=features or {}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels) + labels=labels, + regularization_losses=regularization_losses) # Assert spec contains expected tensors. self.assertIsNotNone(spec.loss) @@ -486,6 +487,38 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_regularization_losses(self): + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + regularization_losses = [1.5, 0.5] + expected_regularization_loss = 2. + # unregularized_loss = sum( + # labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits))) / batch_size + expected_unregularized_loss = np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) / 2. + expected_regularized_loss = ( + expected_unregularized_loss + expected_regularization_loss) + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: expected_unregularized_loss, + keys.LOSS_REGULARIZATION: expected_regularization_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_regularized_loss, + expected_metrics=expected_metrics, + regularization_losses=regularization_losses) + def test_eval_with_label_vocabulary(self): n_classes = 2 head = head_lib.multi_label_head( @@ -829,6 +862,49 @@ class MultiLabelHead(test.TestCase): self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_regularization_losses(self): + head = head_lib.multi_label_head( + n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + regularization_losses = [1.5, 0.5] + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes and over batch and add regularization loss. + expected_loss = 35. / 4. + 2. + expected_summaries = { + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_REGULARIZATION: 2., + } + expected_train_result = 'my_train_op' + def _train_op_fn(loss): + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn, + regularization_losses=regularization_losses) + + # Assert predictions, loss, train_op, and summaries. + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNotNone(spec.scaffold.summary_op) + loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, + spec.scaffold.summary_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + _assert_simple_summaries(self, expected_summaries, summary_str, tol) + def test_train_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 9a5413fc3f2642443621b33d325e3d8c893fd6ac..4d0f9b24240ccbafe89ef912b4d3252cefb1f7f2 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -25,6 +25,7 @@ import time from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -207,6 +209,15 @@ class _ModelFn(object): training_hooks.append( _LossRelativeChangeHook(loss, self._relative_tolerance)) + export_outputs = { + KMeansClustering.ALL_DISTANCES: + export_output.PredictOutput(all_distances[0]), + KMeansClustering.CLUSTER_INDEX: + export_output.PredictOutput(model_predictions[0]), + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(model_predictions[0]) + } + return model_fn_lib.EstimatorSpec( mode=mode, predictions={ @@ -216,7 +227,8 @@ class _ModelFn(object): loss=loss, train_op=training_op, eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)}, - training_hooks=training_hooks) + training_hooks=training_hooks, + export_outputs=export_outputs) # TODO(agarwal,ands): support sharded input. diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index 92fad70b1f9cc55e0690a3fbb35abcf56aa68f16..5ab57ca4cd413bd92f1576278b22d2602c905309 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -44,7 +44,7 @@ const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"}; void Decode(OpKernelContext* context, const tensorflow::StringPiece& file_contents, const string& file_format, const int32 samples_per_second, - const int32 channel_count) { + const int32 channel_count, const string& stream) { // Write the input data to a temp file. const string temp_filename = io::GetTempFilename(file_format); OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents)); @@ -54,7 +54,7 @@ void Decode(OpKernelContext* context, std::vector output_samples; Status result = ffmpeg::ReadAudioFile(temp_filename, file_format, samples_per_second, - channel_count, &output_samples); + channel_count, stream, &output_samples); if (result.code() == error::Code::NOT_FOUND) { OP_REQUIRES( context, result.ok(), @@ -99,7 +99,12 @@ void Decode(OpKernelContext* context, */ class DecodeAudioOpV2 : public OpKernel { public: - explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {} + explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) { + string stream; + if (context->GetAttr("stream", &stream).ok()) { + stream_ = stream; + } + } void Compute(OpKernelContext* context) override { OP_REQUIRES( @@ -153,8 +158,12 @@ class DecodeAudioOpV2 : public OpKernel { errors::InvalidArgument("channel_count must be positive, but got: ", channel_count)); - Decode(context, contents, file_format, samples_per_second, channel_count); + Decode(context, contents, file_format, samples_per_second, channel_count, + stream_); } + + private: + string stream_; }; REGISTER_KERNEL_BUILDER(Name("DecodeAudioV2").Device(DEVICE_CPU), @@ -166,6 +175,7 @@ REGISTER_OP("DecodeAudioV2") .Input("samples_per_second: int32") .Input("channel_count: int32") .Output("sampled_audio: float") + .Attr("stream: string = ''") .SetShapeFn([](shape_inference::InferenceContext* c) { const Tensor* channels_tensor = c->input_tensor(3); if (channels_tensor == nullptr) { @@ -237,7 +247,7 @@ class DecodeAudioOp : public OpKernel { const tensorflow::StringPiece file_contents = contents.scalar()(); Decode(context, file_contents, file_format_, samples_per_second_, - channel_count_); + channel_count_, ""); } private: diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py index 0d7c9cb99e8a5fad4a7ccf86d7253170ace91fd7..3dc663bb6f589d09ed067eae09d7d7dd0c40ec95 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py +++ b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py @@ -33,7 +33,8 @@ class DecodeAudioOpTest(test.TestCase): def _loadFileAndTest(self, filename, file_format, duration_sec, samples_per_second, channel_count, - samples_per_second_tensor=None, feed_dict=None): + samples_per_second_tensor=None, feed_dict=None, + stream=None): """Loads an audio file and validates the output tensor. Args: @@ -49,6 +50,9 @@ class DecodeAudioOpTest(test.TestCase): feed_dict: Used when evaluating the `decode_audio` op. If not provided, will be empty. Useful when providing a placeholder for `samples_per_second_tensor`. + stream: A string specifying which stream from the content file + should be decoded. The default value is '' which leaves the + decision to ffmpeg. """ if samples_per_second_tensor is None: samples_per_second_tensor = samples_per_second @@ -62,7 +66,7 @@ class DecodeAudioOpTest(test.TestCase): contents, file_format=file_format, samples_per_second=samples_per_second_tensor, - channel_count=channel_count) + channel_count=channel_count, stream=stream) audio = audio_op.eval(feed_dict=feed_dict or {}) self.assertEqual(len(audio.shape), 2) self.assertNear( @@ -72,6 +76,17 @@ class DecodeAudioOpTest(test.TestCase): 0.1 * audio.shape[0]) self.assertEqual(audio.shape[1], channel_count) + def testStreamIdentifier(self): + # mono_16khz_mp3_32khz_aac.mp4 was generated from: + # ffmpeg -i tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3.mp4 \ + # -i tensorflow/contrib/ffmpeg/testdata/mono_32khz_aac.mp4 \ + # -strict -2 -map 0:a -map 1:a \ + # tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 + self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000, + 1, stream='0') + self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000, + 1, stream='1') + def testMonoMp3(self): self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1) self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 2) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 1e8af1458cea13b2ddb89b7d93a4ffb8b974ecd2..c85b1837ab5b0c1a3cea0525918f7717228d2fab 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -44,8 +44,10 @@ std::vector FfmpegAudioCommandLine(const string& input_filename, const string& output_filename, const string& input_format_id, int32 samples_per_second, - int32 channel_count) { - return {"-nostats", // No additional progress display. + int32 channel_count, + const string& stream) { + std::vector command({ + "-nostats", // No additional progress display. "-nostdin", // No interactive commands accepted. "-f", input_format_id, // eg: "mp3" "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, @@ -58,8 +60,15 @@ std::vector FfmpegAudioCommandLine(const string& input_filename, // Output set (in several ways) to signed 16-bit little-endian ints. "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", "-sn", // No subtitle recording. - "-y", // Overwrite output file. - StrCat(output_filename)}; + "-y" // Overwrite output file. + }); + if (!stream.empty()) { + command.emplace_back("-map"); + command.emplace_back(StrCat("0:", stream)); + } + command.emplace_back(StrCat(output_filename)); + + return command; } std::vector FfmpegVideoCommandLine(const string& input_filename, @@ -73,7 +82,9 @@ std::vector FfmpegVideoCommandLine(const string& input_filename, "-probesize", StrCat(kDefaultProbeSize), "-loglevel", - "error", // Print errors only. + // Info is needed to get the information about stream, etc. + // It is generated to a separate file, not stdout/stderr. + "info", "-hide_banner", // Skip printing build options, version, etc. "-vcodec", "rawvideo", @@ -123,7 +134,6 @@ bool IsBinaryInstalled(const string& binary_name) { std::transform(args.begin(), args.end(), std::back_inserter(args_chars), [](const string& s) { return const_cast(s.c_str()); }); args_chars.push_back(nullptr); - ::execvp(kFfmpegExecutable, args_chars.data()); // exec only returns on error. const int error = errno; @@ -308,13 +318,12 @@ Status WriteFile(const string& filename, StringPiece contents) { Status ReadAudioFile(const string& filename, const string& audio_format_id, int32 samples_per_second, int32 channel_count, - std::vector* output_samples) { + const string& stream, std::vector* output_samples) { // Create an argument list. string output_filename = io::GetTempFilename("raw"); const std::vector args = FfmpegAudioCommandLine(filename, output_filename, audio_format_id, - samples_per_second, channel_count); - + samples_per_second, channel_count, stream); // Unfortunately, it's impossible to differentiate an exec failure due to the // binary being missing and an error from the binary's execution. Therefore, // check to see if the binary *should* be available. If not, return an error @@ -368,7 +377,6 @@ Status ReadVideoFile(const string& filename, std::vector* output_data, // Create an argument list. const std::vector args = FfmpegVideoCommandLine(filename, output_filename); - // Execute ffmpeg and report errors. pid_t child_pid = ::fork(); if (child_pid < 0) { diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index c5ea1432bf8b61c87615074a93a45325371c4c87..a8d5a0dd83fb504b5e6671c3e82dc7d2dd3e6a9b 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ +#ifndef TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ +#define TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ #include #include @@ -42,7 +42,7 @@ Status WriteFile(const string& filename, tensorflow::StringPiece contents); // contain a separate sample for each channel. Frames are ordered by time. Status ReadAudioFile(const string& filename, const string& audio_format_id, int32 samples_per_second, int32 channel_count, - std::vector* output_samples); + const string& stream, std::vector* output_samples); // Creates an audio file using ffmpeg in a specific format. The samples are in // [-1.0, 1.0]. If there are multiple channels in the audio then each frame will @@ -61,4 +61,4 @@ Status ReadVideoFile(const string& filename, std::vector* output_data, } // namespace ffmpeg } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_ +#endif // TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_ diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 08b5a6ea48c2d4959af68a2ee9d27d21c6245457..020b5c99c61019254bef0b1dff6bc5901c92758a 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -31,7 +31,7 @@ _ffmpeg_so = loader.load_op_library( def decode_audio(contents, file_format=None, samples_per_second=None, - channel_count=None): + channel_count=None, stream=None): """Create an op that decodes the contents of an audio file. Note that ffmpeg is free to select the "best" audio track from an mp4. @@ -51,6 +51,9 @@ def decode_audio(contents, file_format=None, samples_per_second=None, `contents` have more than this number, then some channels will be merged or dropped. If `contents` has fewer than this, then additional channels will be created from the existing ones. + stream: A string specifying which stream from the content file + should be decoded, e.g., '0' means the 0-th stream. + The default value is '' which leaves the decision to ffmpeg. Returns: A rank-2 tensor that has time along dimension 0 and channels along @@ -61,7 +64,7 @@ def decode_audio(contents, file_format=None, samples_per_second=None, """ return gen_decode_audio_op_py.decode_audio_v2( contents, file_format=file_format, samples_per_second=samples_per_second, - channel_count=channel_count) + channel_count=channel_count, stream=stream) ops.NotDifferentiable('DecodeAudio') diff --git a/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 b/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2485da86d60837800fbb0b390c440e674de25993 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 differ diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 2effe8eb26e98caa2707315d5f2e0e530ead31d3..8cdb340f2ddd9b3a7f55c1937ef045f4627e99be 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test @@ -77,6 +78,7 @@ class AssertScalarIntTest(test.TestCase): [3, 4], dtype=dtypes.int32)) +@test_util.with_c_api class WithShapeTest(test.TestCase): def _assert_with_shape(self, tensor, expected_value, expected_shape, @@ -213,16 +215,25 @@ class WithShapeTest(test.TestCase): tensor_partial_shape.set_shape([None, 2]) for incompatible_shape in [[0], [1]]: + if ops._USE_C_API: + error_message = "Shapes must be equal rank, but are 2 and 1" + else: + error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible" self.assertRaisesRegexp( - ValueError, r"Shapes \(\?, 2\) and \([01],\) are not compatible", + ValueError, error_message, tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[1, 2, 1]]: self.assertRaisesRegexp(ValueError, "Dimensions must be equal", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[2, 1]]: + if ops._USE_C_API: + error_message = (r"Dimension 1 in both shapes must be equal, but are " + r"2 and 1. Shapes are \[\?,2\] and \[2,1\].") + else: + error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible" self.assertRaisesRegexp( - ValueError, r"Shapes \(\?, 2\) and \(2, 1\) are not compatible", + ValueError, error_message, tensor_util.with_shape, incompatible_shape, tensor_partial_shape) compatible_shape = [2, 2] diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index fa7a3c03aa35c756252b22a004be91fa24c10e41..ba52697679dafc239b1dac5562573b3589877a8c 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ +#ifndef TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ +#define TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ #if GOOGLE_CUDA @@ -72,4 +72,4 @@ class FusedConvParameters : public ConvParameters { #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ +#endif // TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index b355a79b1a5d967eb82a30d41c073bbb52e0364c..5db34f0f8db93620b8b4a6b71f63b66ac718ee30 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -177,6 +177,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":losses_impl", + ":namedtuples", "//tensorflow/python:util", ], ) @@ -188,6 +189,9 @@ py_test( deps = [ ":tuple_losses", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:variables", "//third_party/py/numpy", ], ) @@ -395,6 +399,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":eval_utils", + ":namedtuples", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 508b4d20d8767f42246a0d0c87f911b7ac612f45..74811ff4096eb5215148f0565bf094b83408014c 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -48,6 +49,15 @@ def add_gan_model_image_summaries(gan_model, grid_size=4): Raises: ValueError: If real and generated data aren't images. """ + if isinstance(gan_model, namedtuples.CycleGANModel): + saved_params = locals() + saved_params.pop('gan_model', None) + with ops.name_scope('cyclegan_x2y_image_summaries'): + add_gan_model_image_summaries(gan_model.model_x2y, **saved_params) + with ops.name_scope('cyclegan_y2x_image_summaries'): + add_gan_model_image_summaries(gan_model.model_y2x, **saved_params) + return + _assert_is_image(gan_model.real_data) _assert_is_image(gan_model.generated_data) @@ -96,6 +106,15 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, ValueError: If the generator input, real, and generated data aren't all the same size. """ + if isinstance(gan_model, namedtuples.CycleGANModel): + saved_params = locals() + saved_params.pop('gan_model', None) + with ops.name_scope('cyclegan_x2y_image_comparison_summaries'): + add_image_comparison_summaries(gan_model.model_x2y, **saved_params) + with ops.name_scope('cyclegan_y2x_image_comparison_summaries'): + add_image_comparison_summaries(gan_model.model_y2x, **saved_params) + return + _assert_is_image(gan_model.generator_inputs) _assert_is_image(gan_model.generated_data) _assert_is_image(gan_model.real_data) @@ -133,6 +152,13 @@ def add_gan_model_summaries(gan_model): Args: gan_model: A GANModel tuple. """ + if isinstance(gan_model, namedtuples.CycleGANModel): + with ops.name_scope('cyclegan_x2y_summaries'): + add_gan_model_summaries(gan_model.model_x2y) + with ops.name_scope('cyclegan_y2x_summaries'): + add_gan_model_summaries(gan_model.model_y2x) + return + with ops.name_scope('generator_variables'): for var in gan_model.generator_variables: summary.histogram(var.name, var) @@ -147,6 +173,13 @@ def add_regularization_loss_summaries(gan_model): Args: gan_model: A GANModel tuple. """ + if isinstance(gan_model, namedtuples.CycleGANModel): + with ops.name_scope('cyclegan_x2y_regularization_loss_summaries'): + add_regularization_loss_summaries(gan_model.model_x2y) + with ops.name_scope('cyclegan_y2x_regularization_loss_summaries'): + add_regularization_loss_summaries(gan_model.model_y2x) + return + if gan_model.generator_scope: summary.scalar( 'generator_regularization_loss', diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index a3b02bcefc6cbaa6e24131b336b5c9c072bde52c..a02d8772e130a2a927735e56c4272aba1f1a6996 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -57,40 +57,83 @@ def get_gan_model(): discriminator_fn=discriminator_model) +def get_cyclegan_model(): + with variable_scope.variable_scope('x2y'): + model_x2y = get_gan_model() + with variable_scope.variable_scope('y2x'): + model_y2x = get_gan_model() + return namedtuples.CycleGANModel( + model_x2y=model_x2y, + model_y2x=model_y2x, + reconstructed_x=array_ops.zeros([3, 30, 35, 6]), + reconstructed_y=array_ops.zeros([3, 30, 35, 6])) + + class SummariesTest(test.TestCase): - def testAddGanModelImageSummaries(self): - summaries.add_gan_model_image_summaries(get_gan_model(), grid_size=2) + def _test_add_gan_model_image_summaries_impl(self, get_model_fn, + expected_num_summary_ops): + summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2) - self.assertEquals(5, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with self.test_session(use_gpu=True): variables.global_variables_initializer().run() summary.merge_all().eval() - def testAddGanModelSummaries(self): - summaries.add_gan_model_summaries(get_gan_model()) + def test_add_gan_model_image_summaries(self): + self._test_add_gan_model_image_summaries_impl(get_gan_model, 5) + + def test_add_gan_model_image_summaries_for_cyclegan(self): + self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10) - self.assertEquals(3, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + def _test_add_gan_model_summaries_impl(self, get_model_fn, + expected_num_summary_ops): + summaries.add_gan_model_summaries(get_model_fn()) + + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with self.test_session(use_gpu=True): variables.global_variables_initializer().run() summary.merge_all().eval() - def testAddRegularizationLossSummaries(self): - summaries.add_regularization_loss_summaries(get_gan_model()) + def test_add_gan_model_summaries(self): + self._test_add_gan_model_summaries_impl(get_gan_model, 3) + + def test_add_gan_model_summaries_for_cyclegan(self): + self._test_add_gan_model_summaries_impl(get_cyclegan_model, 6) - self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + def _test_add_regularization_loss_summaries_impl(self, get_model_fn, + expected_num_summary_ops): + summaries.add_regularization_loss_summaries(get_model_fn()) + + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with self.test_session(use_gpu=True): summary.merge_all().eval() + def test_add_regularization_loss_summaries(self): + self._test_add_regularization_loss_summaries_impl(get_gan_model, 2) + + def test_add_regularization_loss_summaries_for_cyclegan(self): + self._test_add_regularization_loss_summaries_impl(get_cyclegan_model, 4) + # TODO(joelshor): Add correctness test. - def testAddImageComparisonSummaries(self): - summaries.add_image_comparison_summaries( - get_gan_model(), display_diffs=True) + def _test_add_image_comparison_summaries_impl(self, get_model_fn, + expected_num_summary_ops): + summaries.add_image_comparison_summaries(get_model_fn(), display_diffs=True) - self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with self.test_session(use_gpu=True): summary.merge_all().eval() + def test_add_image_comparison_summaries(self): + self._test_add_image_comparison_summaries_impl(get_gan_model, 1) + + def test_add_image_comparison_summaries_for_cyclegan(self): + self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 940762cf2aa0f473cd41d9d543e2773b565a5248..23a3b60cc0055917bfc5243b0ebdbaea7b61edb9 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -67,6 +67,7 @@ __all__ = [ 'wasserstein_gradient_penalty', 'mutual_information_penalty', 'combine_adversarial_loss', + 'cycle_consistency_loss', ] @@ -915,3 +916,63 @@ def combine_adversarial_loss(main_loss, array_ops.stop_gradient(adv_coeff) * adversarial_loss) return final_loss + + +def cycle_consistency_loss(data_x, + reconstructed_data_x, + data_y, + reconstructed_data_y, + scope=None, + add_summaries=False): + """Defines the cycle consistency loss. + + The cyclegan model has two partial models where `model_x2y` generator F maps + data set X to Y, `model_y2x` generator G maps data set Y to X. For a `data_x` + in data set X, we could reconstruct it by + * reconstructed_data_x = G(F(data_x)) + Similarly + * reconstructed_data_y = F(G(data_y)) + + The cycle consistency loss is about the difference between data and + reconstructed data, namely + * loss_x2x = |data_x - G(F(data_x))| (L1-norm) + * loss_y2y = |data_y - F(G(data_y))| (L1-norm) + * loss = (loss_x2x + loss_y2y) / 2 + where `loss` is the final result. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + data_x: A `Tensor` of data X. + reconstructed_data_x: A `Tensor` of reconstructed data X. + data_y: A `Tensor` of data Y. + reconstructed_data_y: A `Tensor` of reconstructed data Y. + scope: The scope for the operations performed in computing the loss. + Defaults to None. + add_summaries: Whether or not to add detailed summaries for the loss. + Defaults to False. + + Returns: + A scalar `Tensor` of cycle consistency loss. + """ + + def _partial_cycle_consistency_loss(data, reconstructed_data): + # Following the original implementation + # https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua + # use L1-norm of pixel-wise error normalized by data size so that + # `cycle_loss_weight` can be specified independent of image size. + return math_ops.reduce_mean(math_ops.abs(data - reconstructed_data)) + + with ops.name_scope( + scope, + 'cycle_consistency_loss', + values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]): + loss_x2x = _partial_cycle_consistency_loss(data_x, reconstructed_data_x) + loss_y2y = _partial_cycle_consistency_loss(data_y, reconstructed_data_y) + loss = (loss_x2x + loss_y2y) / 2.0 + if add_summaries: + summary.scalar('cycle_consistency_loss_x2x', loss_x2x) + summary.scalar('cycle_consistency_loss_y2y', loss_y2y) + summary.scalar('cycle_consistency_loss', loss) + + return loss diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index b5cd8c92ba180e981e0faf877021cb6d69dc34b4..7d2a7a254f6656198e47325dbb351618d85d147c 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -623,5 +623,32 @@ class CombineAdversarialLossTest(test.TestCase): self.assertNear(gnorm_np, precond_gnorm_np, 1e-5) +class CycleConsistencyLossTest(test.TestCase): + """Tests for cycle_consistency_loss.""" + + def setUp(self): + super(CycleConsistencyLossTest, self).setUp() + + self._data_x_np = [[1.0, 2, 3], [4, 5, 6]] + self._reconstructed_data_x_np = [[7.0, 8, 9], [10, 11, 12]] + self._data_y_np = [1.0, 9] + self._reconstructed_data_y_np = [-2.0, 3] + + self._data_x = constant_op.constant(self._data_x_np, dtype=dtypes.float32) + self._reconstructed_data_x = constant_op.constant( + self._reconstructed_data_x_np, dtype=dtypes.float32) + self._data_y = constant_op.constant(self._data_y_np, dtype=dtypes.float32) + self._reconstructed_data_y = constant_op.constant( + self._reconstructed_data_y_np, dtype=dtypes.float32) + + def test_correct_loss(self): + loss = tfgan_losses.cycle_consistency_loss( + self._data_x, self._reconstructed_data_x, self._data_y, + self._reconstructed_data_y) + with self.test_session(use_gpu=True): + variables.global_variables_initializer().run() + self.assertNear(5.25, loss.eval(), 1e-5) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index b341f03a0ddaacca8b036189516c71908bee50eb..dcc3f94c2d6b9e5e44036e7cc1a9d1bb39104fb5 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -60,6 +60,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.losses.python import losses_impl from tensorflow.python.util import tf_inspect @@ -78,6 +79,7 @@ __all__ = [ 'wasserstein_gradient_penalty', 'mutual_information_penalty', 'combine_adversarial_loss', + 'cycle_consistency_loss', ] @@ -246,3 +248,32 @@ def combine_adversarial_loss(gan_loss, scalar_summaries, gradient_summaries) return gan_loss._replace(generator_loss=combined_loss) + + +def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False): + """Defines the cycle consistency loss. + + Uses `cycle_consistency_loss` to compute the cycle consistency loss for a + `cyclegan_model`. + + Args: + cyclegan_model: A `CycleGANModel` namedtuple. + scope: The scope for the operations performed in computing the loss. + Defaults to None. + add_summaries: Whether or not to add detailed summaries for the loss. + Defaults to False. + + Returns: + A scalar `Tensor` of cycle consistency loss. + + Raises: + ValueError: If `cyclegan_model` is not a `CycleGANModel` namedtuple. + """ + if not isinstance(cyclegan_model, namedtuples.CycleGANModel): + raise ValueError( + '`cyclegan_model` must be a `CycleGANModel`. Instead, was %s.' % + type(cyclegan_model)) + return losses_impl.cycle_consistency_loss( + cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x, + cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y, + scope, add_summaries) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index 215b15ef6915d0b8113def35987ed6ab85617bcc..aa1ef11172dee6799994b87f70a3883cd67fd15b 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -22,8 +22,11 @@ import collections import numpy as np +from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses - +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -125,6 +128,7 @@ manual_tests = [ 'combine_adversarial_loss', 'mutual_information_penalty', 'wasserstein_gradient_penalty', + 'cycle_consistency_loss', ] discriminator_keyword_args = { @@ -139,6 +143,38 @@ generator_keyword_args = { } +class CycleConsistencyLossTest(test.TestCase): + + def setUp(self): + super(CycleConsistencyLossTest, self).setUp() + + def _partial_model(generator_inputs_np): + model = namedtuples.GANModel(*[None] * 11) + return model._replace( + generator_inputs=constant_op.constant( + generator_inputs_np, dtype=dtypes.float32)) + + self._model_x2y = _partial_model([1, 2]) + self._model_y2x = _partial_model([5, 6]) + + def test_model_type(self): + """Test the input model type for `cycle_consistency_loss`.""" + with self.assertRaises(ValueError): + tfgan_losses.cycle_consistency_loss(self._model_x2y) + + def test_correct_loss(self): + """Test the output of `cycle_consistency_loss`.""" + loss = tfgan_losses.cycle_consistency_loss( + namedtuples.CycleGANModel( + model_x2y=self._model_x2y, + model_y2x=self._model_y2x, + reconstructed_x=constant_op.constant([9, 8], dtype=dtypes.float32), + reconstructed_y=constant_op.constant([7, 2], dtype=dtypes.float32))) + with self.test_session(use_gpu=True): + variables.global_variables_initializer().run() + self.assertNear(5.0, loss.eval(), 1e-5) + + if __name__ == '__main__': for loss_name in tfgan_losses.__all__: if loss_name in manual_tests: continue diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 3d4e315ebd0bd52b3b5e3e4a8655df8bfe9cebe8..25cfeafeec9000b0dc3849ebe646e59c1b4d1cc3 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -30,7 +30,9 @@ __all__ = [ 'GANModel', 'InfoGANModel', 'ACGANModel', + 'CycleGANModel', 'GANLoss', + 'CycleGANLoss', 'GANTrainOps', 'GANTrainSteps', ] @@ -115,6 +117,25 @@ class ACGANModel( """ +class CycleGANModel( + collections.namedtuple( + 'CycleGANModel', + ('model_x2y', 'model_y2x', 'reconstructed_x', 'reconstructed_y'))): + """An CycleGANModel contains all the pieces needed for CycleGAN training. + + The model `model_x2y` generator F maps data set X to Y, while the model + `model_y2x` generator G maps data set Y to X. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + model_x2y: A `GANModel` namedtuple whose generator maps data set X to Y. + model_y2x: A `GANModel` namedtuple whose generator maps data set Y to X. + reconstructed_x: A `Tensor` of reconstructed data X which is G(F(X)). + reconstructed_y: A `Tensor` of reconstructed data Y which is F(G(Y)). + """ + + class GANLoss( collections.namedtuple('GANLoss', ( 'generator_loss', @@ -128,6 +149,18 @@ class GANLoss( """ +class CycleGANLoss( + collections.namedtuple('CycleGANLoss', ('loss_x2y', 'loss_y2x'))): + """CycleGANLoss contains the losses for `CycleGANModel`. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + loss_x2y: A `GANLoss` namedtuple representing the loss of `model_x2y`. + loss_y2x: A `GANLoss` namedtuple representing the loss of `model_y2x`. + """ + + class GANTrainOps( collections.namedtuple('GANTrainOps', ( 'generator_train_op', diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index c429ec48314b1f036beceb564bcf6d1e2a6d3b2e..5d0ac93aec7869bb1d9b8a174ba50d4bec2c2826 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -52,7 +52,9 @@ __all__ = [ 'gan_model', 'infogan_model', 'acgan_model', + 'cyclegan_model', 'gan_loss', + 'cyclegan_loss', 'gan_train_ops', 'gan_train', 'get_sequential_train_hooks', @@ -277,14 +279,16 @@ def acgan_model( generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) generated_data = generator_fn(generator_inputs) with variable_scope.variable_scope(discriminator_scope) as dis_scope: - (discriminator_gen_outputs, discriminator_gen_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(generated_data, generator_inputs)) + with ops.name_scope(dis_scope.name+'/generated/'): + (discriminator_gen_outputs, discriminator_gen_classification_logits + ) = _validate_acgan_discriminator_outputs( + discriminator_fn(generated_data, generator_inputs)) with variable_scope.variable_scope(dis_scope, reuse=True): - real_data = ops.convert_to_tensor(real_data) - (discriminator_real_outputs, discriminator_real_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(real_data, generator_inputs)) + with ops.name_scope(dis_scope.name+'/real/'): + real_data = ops.convert_to_tensor(real_data) + (discriminator_real_outputs, discriminator_real_classification_logits + ) = _validate_acgan_discriminator_outputs( + discriminator_fn(real_data, generator_inputs)) if check_shapes: if not generated_data.shape.is_compatible_with(real_data.shape): raise ValueError( @@ -305,6 +309,76 @@ def acgan_model( discriminator_gen_classification_logits) +def cyclegan_model( + # Lambdas defining models. + generator_fn, + discriminator_fn, + # data X and Y. + data_x, + data_y, + # Optional scopes. + generator_scope='Generator', + discriminator_scope='Discriminator', + model_x2y_scope='ModelX2Y', + model_y2x_scope='ModelY2X', + # Options. + check_shapes=True): + """Returns a CycleGAN model outputs and variables. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. + data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`. + data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`. + generator_scope: Optional generator variable scope. Useful if you want to + reuse a subgraph that has already been created. Defaults to 'Generator'. + discriminator_scope: Optional discriminator variable scope. Useful if you + want to reuse a subgraph that has already been created. Defaults to + 'Discriminator'. + model_x2y_scope: Optional variable scope for model x2y variables. Defaults + to 'ModelX2Y'. + model_y2x_scope: Optional variable scope for model y2x variables. Defaults + to 'ModelY2X'. + check_shapes: If `True`, check that generator produces Tensors that are the + same shape as `data_x` (`data_y`). Otherwise, skip this check. + + Returns: + A `CycleGANModel` namedtuple. + + Raises: + ValueError: If `check_shapes` is True and `data_x` or the generator output + does not have the same shape as `data_y`. + """ + + # Create models. + def _define_partial_model(input_data, output_data): + return gan_model( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + real_data=output_data, + generator_inputs=input_data, + generator_scope=generator_scope, + discriminator_scope=discriminator_scope, + check_shapes=check_shapes) + + with variable_scope.variable_scope(model_x2y_scope): + model_x2y = _define_partial_model(data_x, data_y) + with variable_scope.variable_scope(model_y2x_scope): + model_y2x = _define_partial_model(data_y, data_x) + + with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True): + reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data) + with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True): + reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data) + + return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x, + reconstructed_y) + + def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'): if isinstance(aux_loss_weight, ops.Tensor): aux_loss_weight.shape.assert_is_compatible_with([]) @@ -494,6 +568,69 @@ def gan_loss( return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss) +def cyclegan_loss( + model, + # Loss functions. + generator_loss_fn=tfgan_losses.least_squares_generator_loss, + discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss, + # Auxiliary losses. + cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss, + cycle_consistency_loss_weight=10.0, + # Options + **kwargs): + """Returns the losses for a `CycleGANModel`. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + model: A `CycleGANModel` namedtuple. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + named tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` namedtuple. + cycle_consistency_loss_fn: The cycle consistency loss function. Takes a + `CycleGANModel` namedtuple. + cycle_consistency_loss_weight: A non-negative Python number or a scalar + `Tensor` indicating how much to weigh the cycle consistency loss. + **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss + for each partial model of `model`. + + Returns: + A `CycleGANLoss` namedtuple. + + Raises: + ValueError: If `model` is not a `CycleGANModel` namedtuple. + """ + # Sanity checks. + if not isinstance(model, namedtuples.CycleGANModel): + raise ValueError( + '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model)) + + # Defines cycle consistency loss. + cycle_consistency_loss = cycle_consistency_loss_fn( + model, add_summaries=kwargs.get('add_summaries', True)) + cycle_consistency_loss_weight = _validate_aux_loss_weight( + cycle_consistency_loss_weight, 'cycle_consistency_loss_weight') + aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss + + # Defines losses for each partial model. + def _partial_loss(partial_model): + partial_loss = gan_loss( + partial_model, + generator_loss_fn=generator_loss_fn, + discriminator_loss_fn=discriminator_loss_fn, + **kwargs) + return partial_loss._replace( + generator_loss=partial_loss.generator_loss + aux_loss) + + with ops.name_scope('cyclegan_loss_x2y'): + loss_x2y = _partial_loss(model.model_x2y) + with ops.name_scope('cyclegan_loss_y2x'): + loss_y2x = _partial_loss(model.model_y2x) + + return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) + + def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): """Gets generator and discriminator update ops. @@ -561,6 +698,24 @@ def gan_train_ops( A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can be used to train a generator/discriminator pair. """ + if isinstance(model, namedtuples.CycleGANModel): + saved_params = locals() + saved_params.pop('model', None) + saved_params.pop('loss', None) + kwargs = saved_params.pop('kwargs', {}) + saved_params.update(kwargs) + with ops.name_scope('cyclegan_x2y_train'): + train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y, + **saved_params) + with ops.name_scope('cyclegan_y2x_train'): + train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x, + **saved_params) + return namedtuples.GANTrainOps( + (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op), + (train_ops_x2y.discriminator_train_op, + train_ops_y2x.discriminator_train_op), + training_util.get_or_create_global_step().assign_add(1)) + # Create global step increment op. global_step = training_util.get_or_create_global_step() global_step_inc = global_step.assign_add(1) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 58704e68594e947041697ec6cb1d240e1f505aae..f9bdaa74c948ecee11d5cfd89f06087924f8dace 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -210,6 +210,38 @@ def create_callable_acgan_model(): one_hot_labels=array_ops.one_hot([0, 1, 2], 10)) +def get_cyclegan_model(): + return namedtuples.CycleGANModel( + model_x2y=get_gan_model(), + model_y2x=get_gan_model(), + reconstructed_x=array_ops.ones([1, 2, 3]), + reconstructed_y=array_ops.zeros([1, 2, 3])) + + +def get_callable_cyclegan_model(): + return namedtuples.CycleGANModel( + model_x2y=get_callable_gan_model(), + model_y2x=get_callable_gan_model(), + reconstructed_x=array_ops.ones([1, 2, 3]), + reconstructed_y=array_ops.zeros([1, 2, 3])) + + +def create_cyclegan_model(): + return train.cyclegan_model( + generator_model, + discriminator_model, + data_x=array_ops.zeros([1, 2]), + data_y=array_ops.ones([1, 2])) + + +def create_callable_cyclegan_model(): + return train.cyclegan_model( + Generator(), + Discriminator(), + data_x=array_ops.zeros([1, 2]), + data_y=array_ops.ones([1, 2])) + + def get_sync_optimizer(): return sync_replicas_optimizer.SyncReplicasOptimizer( gradient_descent.GradientDescentOptimizer(learning_rate=1.0), @@ -261,6 +293,13 @@ class GANModelTest(test.TestCase): self._test_output_type_helper( get_callable_acgan_model, namedtuples.ACGANModel) + def test_output_type_cyclegan(self): + self._test_output_type_helper(get_cyclegan_model, namedtuples.CycleGANModel) + + def test_output_type_callable_cyclegan(self): + self._test_output_type_helper(get_callable_cyclegan_model, + namedtuples.CycleGANModel) + def test_no_shape_check(self): def dummy_generator_model(_): return (None, None) @@ -308,6 +347,17 @@ class GANLossTest(test.TestCase): def test_output_type_callable_acgan(self): self._test_output_type_helper(get_callable_acgan_model) + def test_output_type_cyclegan(self): + loss = train.cyclegan_loss(create_cyclegan_model(), add_summaries=True) + self.assertIsInstance(loss, namedtuples.CycleGANLoss) + self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + + def test_output_type_callable_cyclegan(self): + loss = train.cyclegan_loss( + create_callable_cyclegan_model(), add_summaries=True) + self.assertIsInstance(loss, namedtuples.CycleGANLoss) + self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + # Test gradient penalty option. def _test_grad_penalty_helper(self, create_gan_model_fn): model = create_gan_model_fn() @@ -431,6 +481,34 @@ class GANLossTest(test.TestCase): def test_callable_acgan(self): self._test_acgan_helper(create_callable_acgan_model) + # Test that CycleGan models work. + def _test_cyclegan_helper(self, create_gan_model_fn): + model = create_gan_model_fn() + loss = train.cyclegan_loss(model) + self.assertIsInstance(loss, namedtuples.CycleGANLoss) + + # Check values. + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + (loss_x2y_gen_np, loss_x2y_dis_np, loss_y2x_gen_np, + loss_y2x_dis_np) = sess.run([ + loss.loss_x2y.generator_loss, loss.loss_x2y.discriminator_loss, + loss.loss_y2x.generator_loss, loss.loss_y2x.discriminator_loss + ]) + + self.assertGreater(loss_x2y_gen_np, loss_x2y_dis_np) + self.assertGreater(loss_y2x_gen_np, loss_y2x_dis_np) + self.assertTrue(np.isscalar(loss_x2y_gen_np)) + self.assertTrue(np.isscalar(loss_x2y_dis_np)) + self.assertTrue(np.isscalar(loss_y2x_gen_np)) + self.assertTrue(np.isscalar(loss_y2x_dis_np)) + + def test_cyclegan(self): + self._test_cyclegan_helper(create_cyclegan_model) + + def test_callable_cyclegan(self): + self._test_cyclegan_helper(create_callable_cyclegan_model) + def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2, pool_size): history_values = [] diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index bdbe6f0a72621e59562fe113da101ff5a2b8c06d..707ae25d485c64f15694ee0e357f32b619d3cd33 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -82,6 +82,7 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:graph_mgr", + "//tensorflow/core/distributed_runtime:recent_request_ids", "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", "//tensorflow/core/distributed_runtime:worker", "//tensorflow/core/distributed_runtime:worker_cache", @@ -103,6 +104,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:request_id", "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_env", diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index adef2aac33e3e0839a268eabe2496e58861535c5..28f68cec8cce126f1b177a73e197ccd7ab749f4a 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/tensor_coding.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" @@ -47,6 +48,7 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { recv_args_(recv_args) { req_.set_step_id(step_id); req_.set_rendezvous_key(key.data(), key.size()); + req_.set_request_id(GetUniqueRequestId()); } ~GdrRecvTensorCall() override {} diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 568641234731a458a05886d12066ee9f55fa58aa..ce1d8d2d73000559f03046aceacb169890ecc1b6 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -41,17 +41,26 @@ namespace tensorflow { GdrWorker::GdrWorker(WorkerEnv* worker_env, RemoteMemoryManager* remote_memory_manager) - : GrpcWorker(worker_env), remote_memory_manager_(remote_memory_manager) {} + : GrpcWorker(worker_env), + remote_memory_manager_(remote_memory_manager), + recv_tensor_recent_request_ids_(100000) {} void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { + Status s = recv_tensor_recent_request_ids_.TrackUnique( + request->request_id(), "RecvTensor (GdrWorker)", *request); + if (!s.ok()) { + done(s); + return; + } + const int64 step_id = request->step_id(); const string& key = request->rendezvous_key(); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); Rendezvous::ParsedKey parsed; - Status s = Rendezvous::ParseKey(key, &parsed); + s = Rendezvous::ParseKey(key, &parsed); Device* src_dev = nullptr; if (s.ok()) { s = PrepareRecvTensor(parsed, &src_dev); diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index a30b7baaedcbc80d93d7f37756732c37d2435935..54081f655ec087d78ac07974656257dcf478bcef 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/contrib/gdr/gdr_memory_manager.h" +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" namespace tensorflow { @@ -38,6 +39,7 @@ class GdrWorker : public GrpcWorker { private: RemoteMemoryManager* remote_memory_manager_; // Not owned + RecentRequestIds recv_tensor_recent_request_ids_; }; } // namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h index 194ae2ba47456cac66c01989a78ab4ce607d1295..8968da6d8241ca7cd548910a024a618913c3ed70 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h @@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ #if GOOGLE_CUDA #define EIGEN_USE_GPU @@ -84,4 +84,4 @@ struct AdjustHsvInYiqGPU { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index 63f45ea55b3d1f65a113e8c81a822a08613672df..ae787b6f1ac90218f2ac73d37fb270df0b822de2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -113,6 +113,42 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): self.assertListEqual(loss.input_minibatches, tower_logits) self.assertEqual(loss.num_registered_minibatches, num_towers) + def testMultiplyFisherSingleVector(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([1., 2., 3.]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + # the LossFunction.multiply_fisher docstring only says it supports the + # case where the vector is the same shape as the input natural parameters + # (i.e. the logits here), but here we also test leading dimensions + vector = np.array([1., 2., 3.]) + vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] + + probs = np.exp(logits - np.logaddexp.reduce(logits)) + fisher = np.diag(probs) - np.outer(probs, probs) + + for vector in vectors: + result = loss.multiply_fisher(vector) + expected_result = np.dot(vector, fisher) + self.assertAllClose(expected_result, sess.run(result)) + + def testMultiplyFisherBatch(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([[1., 2., 3.], [4., 6., 8.]]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + vector = np.array([[1., 2., 3.], [5., 3., 1.]]) + + na = np.newaxis + probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, + keepdims=True)) + fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] + + result = loss.multiply_fisher(vector) + expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] + self.assertEqual(sess.run(result).shape, logits.shape) + self.assertAllClose(expected_result, sess.run(result)) + class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index 2daead2a7180fe57b715bd896303cd4c3fbdaca8..cb3e698b9ceab920785adf735f88bd8e535a628f 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -660,19 +660,20 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def multiply_fisher(self, vector): probs = self._probs - return vector * probs - math_ops.reduce_sum(vector * probs, axis=1) * probs + return vector * probs - probs * math_ops.reduce_sum( + vector * probs, axis=-1, keep_dims=True) def multiply_fisher_factor(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - probs * math_ops.reduce_sum( - sqrt_probs * vector, axis=1, keep_dims=True) + sqrt_probs * vector, axis=-1, keep_dims=True) def multiply_fisher_factor_transpose(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( - probs * vector, axis=1, keep_dims=True) + probs * vector, axis=-1, keep_dims=True) def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py index 1f4a3ef568efc459d4a36fcb0d5de7e0bce8335c..e70b4923749d89aba1bd0187857d762305daeb07 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py @@ -225,7 +225,7 @@ class LabeledTensorTest(test_util.Base): tensor = array_ops.placeholder(dtypes.string, [None]) actual = core.LabeledTensor(tensor, ['x']) self.assertIsNone(actual.axes['x'].size) - self.assertIs(actual.axes['x'].value, tensor.get_shape()[0]) + self.assertIsNone(actual.axes['x'].value.value) def test_eq(self): self.assertEqual(self.lt, self.lt) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 6c624929f20503054e0258aad8a843f4a201be64..ef419862b49f4d03d9b711c49155d4ae1252d5bc 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -27,6 +27,7 @@ See the @{$python/contrib.layers} guide. @@convolution2d_transpose @@conv3d_transpose @@convolution3d_transpose +@@dense_to_sparse @@dropout @@elu @@embedding_lookup_unique diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index f3229a1605c72c61d0d1cc638a9a21048ac60cbe..c8e3307ee8b5ded30dc864c4e69452f58685b8f0 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -29,6 +29,7 @@ from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.layers.python.layers import initializers from tensorflow.contrib.layers.python.layers import utils from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops @@ -54,47 +55,18 @@ from tensorflow.python.layers.maxout import maxout # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. -__all__ = ['avg_pool2d', - 'avg_pool3d', - 'batch_norm', - 'bias_add', - 'conv2d', - 'conv3d', - 'conv2d_in_plane', - 'conv2d_transpose', - 'conv3d_transpose', - 'convolution', - 'convolution2d', - 'convolution2d_in_plane', - 'convolution2d_transpose', - 'convolution3d', - 'convolution3d_transpose', - 'dropout', - 'elu', - 'flatten', - 'fully_connected', - 'GDN', - 'gdn', - 'layer_norm', - 'linear', - 'pool', - 'max_pool2d', - 'max_pool3d', - 'one_hot_encoding', - 'relu', - 'relu6', - 'repeat', - 'scale_gradient', - 'separable_conv2d', - 'separable_convolution2d', - 'softmax', - 'spatial_softmax', - 'stack', - 'unit_norm', - 'legacy_fully_connected', - 'legacy_linear', - 'legacy_relu', - 'maxout'] +__all__ = [ + 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', + 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', + 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', + 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', + 'dropout', 'elu', 'flatten', + 'fully_connected', 'GDN', 'gdn', 'layer_norm', 'linear', 'pool', + 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', + 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'softmax', + 'spatial_softmax', 'stack', 'unit_norm', 'legacy_fully_connected', + 'legacy_linear', 'legacy_relu', 'maxout' +] DATA_FORMAT_NCHW = 'NCHW' DATA_FORMAT_NHWC = 'NHWC' @@ -139,13 +111,14 @@ def avg_pool2d(inputs, raise ValueError('data_format has to be either NCHW or NHWC.') with ops.name_scope(scope, 'AvgPool2D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.AveragePooling2D(pool_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.AveragePooling2D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -187,13 +160,14 @@ def avg_pool3d(inputs, raise ValueError('data_format has to be either NCDHW or NDHWC.') with ops.name_scope(scope, 'AvgPool3D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.AveragePooling3D(pool_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.AveragePooling3D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -298,8 +272,8 @@ def _fused_batch_norm(inputs, raise ValueError('Inputs %s has undefined rank' % inputs.name) elif original_rank not in [2, 4]: raise ValueError('Inputs %s has unsupported rank.' - ' Expected 2 or 4 but got %d' % ( - inputs.name, original_rank)) + ' Expected 2 or 4 but got %d' % (inputs.name, + original_rank)) if original_rank == 2: channels = inputs.get_shape()[-1].value if channels is None: @@ -393,6 +367,7 @@ def _fused_batch_norm(inputs, def _fused_batch_norm_training(): return nn.fused_batch_norm( inputs, gamma, beta, epsilon=epsilon, data_format=data_format) + def _fused_batch_norm_inference(): return nn.fused_batch_norm( inputs, @@ -403,9 +378,9 @@ def _fused_batch_norm(inputs, epsilon=epsilon, is_training=False, data_format=data_format) - outputs, mean, variance = utils.smart_cond(is_training, - _fused_batch_norm_training, - _fused_batch_norm_inference) + + outputs, mean, variance = utils.smart_cond( + is_training, _fused_batch_norm_training, _fused_batch_norm_inference) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and @@ -415,6 +390,7 @@ def _fused_batch_norm(inputs, if need_updates: if updates_collections is None: no_updates = lambda: outputs + def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( @@ -424,9 +400,11 @@ def _fused_batch_norm(inputs, with ops.control_dependencies( [update_moving_mean, update_moving_variance]): return array_ops.identity(outputs) + outputs = utils.smart_cond(is_training, _force_updates, no_updates) else: moving_vars_fn = lambda: (moving_mean, moving_variance) + def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( @@ -434,9 +412,9 @@ def _fused_batch_norm(inputs, update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance - update_mean, update_variance = utils.smart_cond(is_training, - _delay_updates, - moving_vars_fn) + + update_mean, update_variance = utils.smart_cond( + is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) @@ -479,7 +457,12 @@ def batch_norm(inputs, Sergey Ioffe, Christian Szegedy - Can be used as a normalizer function for conv2d and fully_connected. + Can be used as a normalizer function for conv2d and fully_connected. The + normalization is over all but the last dimension if `data_format` is `NHWC` + and all but the second dimension if `data_format` is `NCHW`. In case of a 2D + tensor this corresponds to the batch dimension, while in case of a 4D tensor + this + corresponds to the batch and space dimensions. Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they @@ -588,10 +571,9 @@ def batch_norm(inputs, # implementation in normalization_layers.BatchNormalization. inputs = ops.convert_to_tensor(inputs) rank = inputs.get_shape().ndims - possible_to_fuse = (batch_weights is None and - not renorm and - rank in [2, 4] and - adjustment is None) + possible_to_fuse = ( + batch_weights is None and not renorm and rank in [2, 4] and + adjustment is None) if fused and possible_to_fuse and ( zero_debias_moving_mean or rank == 2 or updates_collections is not ops.GraphKeys.UPDATE_OPS): @@ -619,7 +601,9 @@ def batch_norm(inputs, layer_variable_getter = _build_variable_getter() with variable_scope.variable_scope( - scope, 'BatchNorm', [inputs], reuse=reuse, + scope, + 'BatchNorm', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) @@ -667,15 +651,15 @@ def batch_norm(inputs, outputs = layer.apply(inputs, training=is_training) # Add variables to collections. - _add_variable_to_collections( - layer.moving_mean, variables_collections, 'moving_mean') - _add_variable_to_collections( - layer.moving_variance, variables_collections, 'moving_variance') + _add_variable_to_collections(layer.moving_mean, variables_collections, + 'moving_mean') + _add_variable_to_collections(layer.moving_variance, variables_collections, + 'moving_variance') if layer.beta is not None: _add_variable_to_collections(layer.beta, variables_collections, 'beta') if layer.gamma is not None: - _add_variable_to_collections( - layer.gamma, variables_collections, 'gamma') + _add_variable_to_collections(layer.gamma, variables_collections, + 'gamma') if activation_fn is not None: outputs = activation_fn(outputs) @@ -715,8 +699,8 @@ def batch_norm(inputs, params_shape = inputs_shape[-1:] params_shape_broadcast = None if not params_shape.is_fully_defined(): - raise ValueError('Inputs %s has undefined channels dimension %s.' % ( - inputs.name, params_shape)) + raise ValueError('Inputs %s has undefined channels dimension %s.' % + (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None @@ -727,23 +711,25 @@ def batch_norm(inputs, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) - beta = variables.model_variable('beta', - shape=params_shape, - dtype=dtype, - initializer=beta_initializer, - collections=beta_collections, - trainable=trainable) + beta = variables.model_variable( + 'beta', + shape=params_shape, + dtype=dtype, + initializer=beta_initializer, + collections=beta_collections, + trainable=trainable) if scale: - gamma_collections = utils.get_variable_collections(variables_collections, - 'gamma') + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) - gamma = variables.model_variable('gamma', - shape=params_shape, - dtype=dtype, - initializer=gamma_initializer, - collections=gamma_collections, - trainable=trainable) + gamma = variables.model_variable( + 'gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. We disable variable partitioning while creating @@ -792,8 +778,8 @@ def batch_norm(inputs, mean, variance = nn.moments(inputs, moments_axes) else: if data_format == DATA_FORMAT_NCHW: - mean, variance = nn.weighted_moments(inputs, moments_axes, - batch_weights, keep_dims=True) + mean, variance = nn.weighted_moments( + inputs, moments_axes, batch_weights, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: @@ -802,19 +788,21 @@ def batch_norm(inputs, moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: + def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) - with ops.control_dependencies([update_moving_mean, - update_moving_variance]): + with ops.control_dependencies( + [update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity(variance) - mean, variance = utils.smart_cond(is_training, - _force_updates, + + mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: + def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( @@ -823,9 +811,8 @@ def batch_norm(inputs, moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance - update_mean, update_variance = utils.smart_cond(is_training, - _delay_updates, - moving_vars_fn) + update_mean, update_variance = utils.smart_cond( + is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. @@ -893,8 +880,8 @@ def bias_add(inputs, """ if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') - with variable_scope.variable_scope(scope, 'BiasAdd', [inputs], - reuse=reuse) as sc: + with variable_scope.variable_scope( + scope, 'BiasAdd', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) dtype = inputs.dtype.base_dtype inputs_shape = inputs.get_shape() @@ -909,13 +896,16 @@ def bias_add(inputs, raise ValueError('`C` dimension must be known but is None') biases_collections = utils.get_variable_collections(variables_collections, 'biases') - biases = variables.model_variable('biases', - shape=[num_features,], - dtype=dtype, - initializer=initializer, - regularizer=regularizer, - collections=biases_collections, - trainable=trainable) + biases = variables.model_variable( + 'biases', + shape=[ + num_features, + ], + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + collections=biases_collections, + trainable=trainable) outputs = nn.bias_add(inputs, biases, data_format=data_format) if activation_fn is not None: outputs = activation_fn(outputs) @@ -1015,8 +1005,10 @@ def convolution(inputs, if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']: raise ValueError('Invalid data_format: %r' % (data_format,)) - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( scope, 'Conv', [inputs], reuse=reuse, @@ -1034,26 +1026,27 @@ def convolution(inputs, raise ValueError('Convolution not supported for input with rank', input_rank) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = layer_class(filters=num_outputs, - kernel_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - dilation_rate=rate, - activation=None, - use_bias=not normalizer_fn and biases_initializer, - kernel_initializer=weights_initializer, - bias_initializer=biases_initializer, - kernel_regularizer=weights_regularizer, - bias_regularizer=biases_regularizer, - activity_regularizer=None, - trainable=trainable, - name=sc.name, - dtype=inputs.dtype.base_dtype, - _scope=sc, - _reuse=reuse) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = layer_class( + filters=num_outputs, + kernel_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + dilation_rate=rate, + activation=None, + use_bias=not normalizer_fn and biases_initializer, + kernel_initializer=weights_initializer, + bias_initializer=biases_initializer, + kernel_regularizer=weights_regularizer, + bias_regularizer=biases_regularizer, + activity_regularizer=None, + trainable=trainable, + name=sc.name, + dtype=inputs.dtype.base_dtype, + _scope=sc, + _reuse=reuse) outputs = layer.apply(inputs) # Add variables to collections. @@ -1069,6 +1062,7 @@ def convolution(inputs, outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + convolution2d = convolution convolution3d = convolution @@ -1144,13 +1138,14 @@ def convolution2d_in_plane( weights_shape = [kernel_h, kernel_w, 1, 1] weights_collections = utils.get_variable_collections( variables_collections, 'weights') - weights = variables.model_variable('weights', - shape=weights_shape, - dtype=dtype, - initializer=weights_initializer, - regularizer=weights_regularizer, - collections=weights_collections, - trainable=trainable) + weights = variables.model_variable( + 'weights', + shape=weights_shape, + dtype=dtype, + initializer=weights_initializer, + regularizer=weights_regularizer, + collections=weights_collections, + trainable=trainable) depthwise_weights = array_ops.tile(weights, [1, 1, num_filters_in, 1]) outputs = nn.depthwise_conv2d(inputs, depthwise_weights, [1, stride_h, stride_w, 1], padding) @@ -1161,13 +1156,16 @@ def convolution2d_in_plane( if biases_initializer is not None: biases_collections = utils.get_variable_collections( variables_collections, 'biases') - biases = variables.model_variable('biases', - shape=[num_filters_in,], - dtype=dtype, - initializer=biases_initializer, - regularizer=biases_regularizer, - collections=biases_collections, - trainable=trainable) + biases = variables.model_variable( + 'biases', + shape=[ + num_filters_in, + ], + dtype=dtype, + initializer=biases_initializer, + regularizer=biases_regularizer, + collections=biases_collections, + trainable=trainable) outputs = nn.bias_add(outputs, biases) if activation_fn is not None: @@ -1240,19 +1238,23 @@ def convolution2d_transpose( ValueError: If `data_format` is neither `NHWC` nor `NCHW`. ValueError: If `C` dimension of `inputs` is None. """ - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( - scope, 'Conv2d_transpose', [inputs], reuse=reuse, + scope, + 'Conv2d_transpose', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') layer = convolutional_layers.Convolution2DTranspose( filters=num_outputs, kernel_size=kernel_size, @@ -1349,19 +1351,23 @@ def convolution3d_transpose( ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`. ValueError: If `C` dimension of `inputs` is None. """ - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( - scope, 'Conv3d_transpose', [inputs], reuse=reuse, + scope, + 'Conv3d_transpose', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC): raise ValueError('data_format has to be either NCDHW or NDHWC.') inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') layer = convolutional_layers.Convolution3DTranspose( filters=num_outputs, kernel_size=kernel_size, @@ -1396,6 +1402,29 @@ def convolution3d_transpose( return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): + """Converts a dense tensor into a sparse tensor. + An example use would be to convert dense labels to sparse ones + so that they can be fed to the ctc_loss. + + Args: + tensor: An `int` `Tensor` to be converted to a `Sparse`. + eos_token: An integer. + It is part of the target label that signfies the end of a sentence. + outputs_collections: Collection to add the outputs. + scope: Optional scope for name_scope. + """ + with variable_scope.variable_scope( + scope, 'dense_to_sparse', [tensor]) as sc: + tensor = ops.convert_to_tensor(tensor) + indices = array_ops.where(math_ops.not_equal(tensor, constant_op.constant(eos_token, tensor.dtype))) + values = array_ops.gather_nd(tensor, indices) + shape = array_ops.shape(tensor, out_type=dtypes.int64) + outputs = sparse_tensor.SparseTensor(indices, values, shape) + return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + + @add_arg_scope def dropout(inputs, keep_prob=0.5, @@ -1430,19 +1459,18 @@ def dropout(inputs, with variable_scope.variable_scope( scope, 'Dropout', [inputs], custom_getter=_model_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) - layer = core_layers.Dropout(rate=1 - keep_prob, - noise_shape=noise_shape, - seed=seed, - name=sc.name, - _scope=sc) + layer = core_layers.Dropout( + rate=1 - keep_prob, + noise_shape=noise_shape, + seed=seed, + name=sc.name, + _scope=sc) outputs = layer.apply(inputs, training=is_training) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) @add_arg_scope -def flatten(inputs, - outputs_collections=None, - scope=None): +def flatten(inputs, outputs_collections=None, scope=None): """Flattens the input while maintaining the batch_size. Assumes that the first dimension represents the batch. @@ -1474,8 +1502,8 @@ def _sparse_inner_flatten(inputs, new_rank): outer_dimensions = inputs.dense_shape[:new_rank - 1] inner_dimensions = inputs.dense_shape[new_rank - 1:] - new_shape = array_ops.concat((outer_dimensions, - [math_ops.reduce_prod(inner_dimensions)]), 0) + new_shape = array_ops.concat( + (outer_dimensions, [math_ops.reduce_prod(inner_dimensions)]), 0) flattened = sparse_ops.sparse_reshape(inputs, new_shape) return flattened @@ -1541,10 +1569,18 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): return utils.collect_named_outputs(output_collections, sc, flattened) -def _model_variable_getter(getter, name, shape=None, dtype=None, - initializer=None, regularizer=None, trainable=True, - collections=None, caching_device=None, - partitioner=None, rename=None, use_resource=None, +def _model_variable_getter(getter, + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + rename=None, + use_resource=None, **_): """Getter that uses model_variable for compatibility with core layers.""" short_name = name.split('/')[-1] @@ -1553,25 +1589,34 @@ def _model_variable_getter(getter, name, shape=None, dtype=None, name_components[-1] = rename[short_name] name = '/'.join(name_components) return variables.model_variable( - name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, collections=collections, trainable=trainable, - caching_device=caching_device, partitioner=partitioner, - custom_getter=getter, use_resource=use_resource) + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + collections=collections, + trainable=trainable, + caching_device=caching_device, + partitioner=partitioner, + custom_getter=getter, + use_resource=use_resource) def _build_variable_getter(rename=None): """Build a model variable getter that respects scope getter and renames.""" + # VariableScope will nest the getters def layer_variable_getter(getter, *args, **kwargs): kwargs['rename'] = rename return _model_variable_getter(getter, *args, **kwargs) + return layer_variable_getter def _add_variable_to_collections(variable, collections_set, collections_name): """Adds variable (or all its parts) to all collections with that name.""" - collections = utils.get_variable_collections( - collections_set, collections_name) or [] + collections = utils.get_variable_collections(collections_set, + collections_name) or [] variables_list = [variable] if isinstance(variable, tf_variables.PartitionedVariable): variables_list = [v for v in variable] @@ -1640,15 +1685,19 @@ def fully_connected(inputs, ValueError: If x has rank less than 2 or if its last dimension is not set. """ if not isinstance(num_outputs, six.integer_types): - raise ValueError( - 'num_outputs should be int or long, got %s.' % (num_outputs,)) + raise ValueError('num_outputs should be int or long, got %s.' % + (num_outputs,)) - layer_variable_getter = _build_variable_getter({'bias': 'biases', - 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( - scope, 'fully_connected', [inputs], - reuse=reuse, custom_getter=layer_variable_getter) as sc: + scope, + 'fully_connected', [inputs], + reuse=reuse, + custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) layer = core_layers.Dense( units=num_outputs, @@ -1754,15 +1803,17 @@ class GDN(base.Layer): inverse=False, beta_min=1e-6, gamma_init=.1, - reparam_offset=2 ** -18, + reparam_offset=2**-18, data_format='channels_last', activity_regularizer=None, trainable=True, name=None, **kwargs): - super(GDN, self).__init__(trainable=trainable, name=name, - activity_regularizer=activity_regularizer, - **kwargs) + super(GDN, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) self.inverse = inverse self._beta_min = beta_min self._gamma_init = gamma_init @@ -1797,8 +1848,9 @@ class GDN(base.Layer): with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope: inputs = ops.convert_to_tensor(inputs, name='inputs') bound = ops.convert_to_tensor(bound, name='bound') - with ops.get_default_graph().gradient_override_map( - {'Maximum': 'GDNLowerBound'}): + with ops.get_default_graph().gradient_override_map({ + 'Maximum': 'GDNLowerBound' + }): return math_ops.maximum(inputs, bound, name=scope) @staticmethod @@ -1825,12 +1877,14 @@ class GDN(base.Layer): raise ValueError('The channel dimension of the inputs to `GDN` ' 'must be defined.') self._input_rank = input_shape.ndims - self.input_spec = base.InputSpec(ndim=input_shape.ndims, - axes={channel_axis: num_channels}) + self.input_spec = base.InputSpec( + ndim=input_shape.ndims, axes={ + channel_axis: num_channels + }) - pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype) + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) beta_bound = array_ops.constant( - (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype) + (self._beta_min + self._reparam_offset**2)**.5, dtype=self.dtype) gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype) def beta_initializer(shape, dtype=None, partition_info=None): @@ -1844,19 +1898,21 @@ class GDN(base.Layer): eye = linalg_ops.eye(shape[0], dtype=dtype) return math_ops.sqrt(self._gamma_init * eye + pedestal) - beta = self.add_variable('reparam_beta', - shape=[num_channels], - initializer=beta_initializer, - dtype=self.dtype, - trainable=True) + beta = self.add_variable( + 'reparam_beta', + shape=[num_channels], + initializer=beta_initializer, + dtype=self.dtype, + trainable=True) beta = self._lower_bound(beta, beta_bound) self.beta = math_ops.square(beta) - pedestal - gamma = self.add_variable('reparam_gamma', - shape=[num_channels, num_channels], - initializer=gamma_initializer, - dtype=self.dtype, - trainable=True) + gamma = self.add_variable( + 'reparam_gamma', + shape=[num_channels, num_channels], + initializer=gamma_initializer, + dtype=self.dtype, + trainable=True) gamma = self._lower_bound(gamma, gamma_bound) self.gamma = math_ops.square(gamma) - pedestal @@ -1871,8 +1927,11 @@ class GDN(base.Layer): # Compute normalization pool. if self.data_format == 'channels_first': - norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID', - data_format='NC' + 'DHW'[-(ndim - 2):]) + norm_pool = nn.convolution( + math_ops.square(inputs), + gamma, + 'VALID', + data_format='NC' + 'DHW' [-(ndim - 2):]) if ndim == 3: norm_pool = array_ops.expand_dims(norm_pool, 2) norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW') @@ -1914,7 +1973,7 @@ def gdn(inputs, inverse=False, beta_min=1e-6, gamma_init=.1, - reparam_offset=2 ** -18, + reparam_offset=2**-18, data_format='channels_last', activity_regularizer=None, trainable=True, @@ -1980,17 +2039,18 @@ def gdn(inputs, Returns: Output tensor. """ - layer = GDN(inverse=inverse, - beta_min=beta_min, - gamma_init=gamma_init, - reparam_offset=reparam_offset, - data_format=data_format, - activity_regularizer=activity_regularizer, - trainable=trainable, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) + layer = GDN( + inverse=inverse, + beta_min=beta_min, + gamma_init=gamma_init, + reparam_offset=reparam_offset, + data_format=data_format, + activity_regularizer=activity_regularizer, + trainable=trainable, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) return layer.apply(inputs) @@ -2066,8 +2126,8 @@ def layer_norm(inputs, or if `inputs.shape[begin_params_axis:]` is not fully defined at graph build time. """ - with variable_scope.variable_scope(scope, 'LayerNorm', [inputs], - reuse=reuse) as sc: + with variable_scope.variable_scope( + scope, 'LayerNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) inputs_shape = inputs.shape inputs_rank = inputs_shape.ndims @@ -2077,15 +2137,14 @@ def layer_norm(inputs, if begin_norm_axis < 0: begin_norm_axis = inputs_rank + begin_norm_axis if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank: - raise ValueError( - 'begin_params_axis (%d) and begin_norm_axis (%d) ' - 'must be < rank(inputs) (%d)' - % (begin_params_axis, begin_norm_axis, inputs_rank)) + raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) ' + 'must be < rank(inputs) (%d)' % + (begin_params_axis, begin_norm_axis, inputs_rank)) params_shape = inputs_shape[begin_params_axis:] if not params_shape.is_fully_defined(): raise ValueError( - 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % ( - inputs.name, begin_params_axis, inputs_shape)) + 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % + (inputs.name, begin_params_axis, inputs_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None if center: @@ -2099,8 +2158,8 @@ def layer_norm(inputs, collections=beta_collections, trainable=trainable) if scale: - gamma_collections = utils.get_variable_collections(variables_collections, - 'gamma') + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') gamma = variables.model_variable( 'gamma', shape=params_shape, @@ -2114,7 +2173,11 @@ def layer_norm(inputs, # Compute layer normalization using the batch_normalization function. variance_epsilon = 1e-12 outputs = nn.batch_normalization( - inputs, mean, variance, offset=beta, scale=gamma, + inputs, + mean, + variance, + offset=beta, + scale=gamma, variance_epsilon=variance_epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: @@ -2160,13 +2223,14 @@ def max_pool2d(inputs, raise ValueError('data_format has to be either NCHW or NHWC.') with ops.name_scope(scope, 'MaxPool2D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.MaxPooling2D(pool_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.MaxPooling2D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -2209,13 +2273,14 @@ def max_pool3d(inputs, raise ValueError('data_format has to be either NCDHW or NDHWC.') with ops.name_scope(scope, 'MaxPool3D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.MaxPooling3D(pool_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.MaxPooling3D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -2268,8 +2333,8 @@ def pool(inputs, """ # pylint: enable=line-too-long - with ops.name_scope(scope, '%s_pool' % - (pooling_type.lower()), [inputs]) as sc: + with ops.name_scope(scope, '%s_pool' % (pooling_type.lower()), + [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) input_rank = inputs.get_shape().ndims if input_rank is None: @@ -2314,18 +2379,16 @@ def one_hot_encoding(labels, labels = ops.convert_to_tensor(labels) if labels.dtype == dtypes.int32: labels = standard_ops.to_int64(labels) - outputs = standard_ops.one_hot(labels, - num_classes, - on_value=on_value, - off_value=off_value) + outputs = standard_ops.one_hot( + labels, num_classes, on_value=on_value, off_value=off_value) return utils.collect_named_outputs(outputs_collections, sc, outputs) def _apply_activation(y, activation_fn, output_collections): if activation_fn is not None: y = activation_fn(y) - ops.add_to_collections(list(output_collections or []) + - [ops.GraphKeys.ACTIVATIONS], y) + ops.add_to_collections( + list(output_collections or []) + [ops.GraphKeys.ACTIVATIONS], y) return y @@ -2370,7 +2433,7 @@ def repeat(inputs, repetitions, layer, *args, **kwargs): scope = 'repeat' outputs = inputs for i in range(repetitions): - kwargs['scope'] = scope + '_' + str(i+1) + kwargs['scope'] = scope + '_' + str(i + 1) outputs = layer(outputs, *args, **kwargs) return outputs @@ -2385,8 +2448,8 @@ def _scale_gradient_grad(op, grad): return [grad * op.inputs[1], None] -@function.Defun(python_grad_func=_scale_gradient_grad, - shape_func=_scale_gradient_shape) +@function.Defun( + python_grad_func=_scale_gradient_grad, shape_func=_scale_gradient_shape) def scale_gradient(inputs, gradient_multiplier): """Identity operation, but with the gradient multiplied by a tensor. @@ -2491,18 +2554,21 @@ def separable_convolution2d( """ if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', - 'depthwise_kernel': 'depthwise_weights', - 'pointwise_kernel': 'pointwise_weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'depthwise_kernel': 'depthwise_weights', + 'pointwise_kernel': 'pointwise_weights' + }) with variable_scope.variable_scope( - scope, 'SeparableConv2d', [inputs], reuse=reuse, + scope, + 'SeparableConv2d', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') if num_outputs is not None: # Apply separable conv using the SeparableConvolution2D layer. layer = convolutional_layers.SeparableConvolution2D( @@ -2535,8 +2601,8 @@ def separable_convolution2d( _add_variable_to_collections(layer.pointwise_kernel, variables_collections, 'weights') if layer.bias is not None: - _add_variable_to_collections(layer.bias, - variables_collections, 'biases') + _add_variable_to_collections(layer.bias, variables_collections, + 'biases') if normalizer_fn is not None: normalizer_params = normalizer_params or {} @@ -2551,8 +2617,7 @@ def separable_convolution2d( weights_collections = utils.get_variable_collections( variables_collections, 'weights') - depthwise_shape = [kernel_h, kernel_w, - num_filters_in, depth_multiplier] + depthwise_shape = [kernel_h, kernel_w, num_filters_in, depth_multiplier] depthwise_weights = variables.model_variable( 'depthwise_weights', shape=depthwise_shape, @@ -2566,9 +2631,13 @@ def separable_convolution2d( 1, stride_h, stride_w, 1 ] - outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, - rate=utils.two_element_tuple(rate), - data_format=data_format) + outputs = nn.depthwise_conv2d( + inputs, + depthwise_weights, + strides, + padding, + rate=utils.two_element_tuple(rate), + data_format=data_format) num_outputs = depth_multiplier * num_filters_in if normalizer_fn is not None: @@ -2578,13 +2647,16 @@ def separable_convolution2d( if biases_initializer is not None: biases_collections = utils.get_variable_collections( variables_collections, 'biases') - biases = variables.model_variable('biases', - shape=[num_outputs,], - dtype=dtype, - initializer=biases_initializer, - regularizer=biases_regularizer, - trainable=trainable, - collections=biases_collections) + biases = variables.model_variable( + 'biases', + shape=[ + num_outputs, + ], + dtype=dtype, + initializer=biases_initializer, + regularizer=biases_regularizer, + trainable=trainable, + collections=biases_collections) outputs = nn.bias_add(outputs, biases, data_format=data_format) if activation_fn is not None: @@ -2669,23 +2741,24 @@ def spatial_softmax(features, with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): # Create tensors for x and y coordinate values, scaled to range [-1, 1]. - pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), - math_ops.lin_space(-1., 1., num=width), - indexing='ij') + pos_x, pos_y = array_ops.meshgrid( + math_ops.lin_space(-1., 1., num=height), + math_ops.lin_space(-1., 1., num=width), + indexing='ij') pos_x = array_ops.reshape(pos_x, [height * width]) pos_y = array_ops.reshape(pos_y, [height * width]) - + if temperature is None: temp_initializer = init_ops.ones_initializer() else: temp_initializer = init_ops.constant_initializer(temperature) - + if not trainable: temp_collections = None else: temp_collections = utils.get_variable_collections( - variables_collections, 'temperature') - + variables_collections, 'temperature') + temperature = variables.model_variable( 'temperature', shape=(), @@ -2699,14 +2772,14 @@ def spatial_softmax(features, features = array_ops.reshape( array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) - softmax_attention = nn.softmax(features/temperature) + softmax_attention = nn.softmax(features / temperature) expected_x = math_ops.reduce_sum( pos_x * softmax_attention, [1], keep_dims=True) expected_y = math_ops.reduce_sum( pos_y * softmax_attention, [1], keep_dims=True) expected_xy = array_ops.concat([expected_x, expected_y], 1) - feature_keypoints = array_ops.reshape( - expected_xy, [-1, num_channels.value * 2]) + feature_keypoints = array_ops.reshape(expected_xy, + [-1, num_channels.value * 2]) feature_keypoints.set_shape([None, num_channels.value * 2]) return feature_keypoints @@ -2758,7 +2831,7 @@ def stack(inputs, layer, stack_args, **kwargs): scope = 'stack' outputs = inputs for i in range(len(stack_args)): - kwargs['scope'] = scope + '_' + str(i+1) + kwargs['scope'] = scope + '_' + str(i + 1) layer_args = stack_args[i] if not isinstance(layer_args, (list, tuple)): layer_args = [layer_args] @@ -2789,11 +2862,10 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None): raise ValueError('The input rank must be known.') input_rank = len(inputs.get_shape().as_list()) if dim < 0 or dim >= input_rank: - raise ValueError( - 'dim must be positive but smaller than the input rank.') + raise ValueError('dim must be positive but smaller than the input rank.') - lengths = math_ops.sqrt(epsilon + math_ops.reduce_sum( - math_ops.square(inputs), dim, True)) + lengths = math_ops.sqrt( + epsilon + math_ops.reduce_sum(math_ops.square(inputs), dim, True)) multiples = [] if dim > 0: multiples.append(array_ops.ones([dim], dtypes.int32)) @@ -2934,29 +3006,31 @@ def legacy_fully_connected(x, raise ValueError('last dimension of x must be known but is None') dtype = x.dtype.base_dtype - weight_collections = set(list(weight_collections or []) + - [ops.GraphKeys.GLOBAL_VARIABLES]) - w = variable_scope.get_variable('weights', - shape=[num_input_units, num_output_units], - dtype=dtype, - initializer=weight_init, - collections=weight_collections, - regularizer=weight_regularizer, - trainable=trainable) - x_2_dim = x if len(dims) <= 2 else array_ops.reshape(x, - [-1, num_input_units]) + weight_collections = set( + list(weight_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES]) + w = variable_scope.get_variable( + 'weights', + shape=[num_input_units, num_output_units], + dtype=dtype, + initializer=weight_init, + collections=weight_collections, + regularizer=weight_regularizer, + trainable=trainable) + x_2_dim = x if len(dims) <= 2 else array_ops.reshape( + x, [-1, num_input_units]) y = standard_ops.matmul(x_2_dim, w) if bias_init is not None: - bias_collections = set(list(bias_collections or []) + - [ops.GraphKeys.GLOBAL_VARIABLES]) - b = variable_scope.get_variable('bias', - shape=[num_output_units], - dtype=dtype, - initializer=bias_init, - collections=bias_collections, - regularizer=bias_regularizer, - trainable=trainable) + bias_collections = set( + list(bias_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES]) + b = variable_scope.get_variable( + 'bias', + shape=[num_output_units], + dtype=dtype, + initializer=bias_init, + collections=bias_collections, + regularizer=bias_regularizer, + trainable=trainable) y = nn.bias_add(y, b) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index a9bdbe01387653bada1f1e5e9948db7a737eb600..c5790c76221848524a106f1a218922f4e7a0b7e6 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope @@ -1292,6 +1293,17 @@ class ConvolutionInPlaneTest(test.TestCase): self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) +class DenseToSparseTest(test.TestCase): + + def testDenseFromConstantToSparse(self): + expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2)) + tensor = constant_op.constant(expected_constant) + sparse = _layers.dense_to_sparse(tensor) + dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, sparse.values) + with self.test_session() as sess: + constant = sess.run(dense) + self.assertAllEqual(expected_constant, constant) + class DropoutTest(test.TestCase): def testCreateDropout(self): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index ee3611ca9385e80d30e42f8405c8ac318e66771b..3c782b54a8559a6aac19d12ea11a9c76bffdb9c3 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -494,7 +494,7 @@ py_test( name = "linear_test", size = "medium", srcs = ["python/learn/estimators/linear_test.py"], - shard_count = 4, + shard_count = 20, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp index 2b91f1e8c900ab8ab1d99cb803944821aa038d84..8a5d9d5df29cc98f3aeab0632295288c4c196c9d 100644 --- a/tensorflow/contrib/lite/Android.bp +++ b/tensorflow/contrib/lite/Android.bp @@ -52,6 +52,7 @@ cc_library_static { "gemmlowp_headers", ], cflags: [ + "-Wno-extern-c-compat", "-Wno-mismatched-tags", "-Wno-sign-compare", "-Wno-unused-lambda-capture", diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h index ee8a7ccd0b232f9e48095567fd4aefe94f595bc3..68aee2e64473320c461ec8b3f194904e7b8da43c 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/contrib/lite/allocation.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Main abstraction controlling the tflite interpreter. // See context.h for the API for defining operations (TfLiteRegistration). -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#define TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ #include #include @@ -91,4 +91,4 @@ class MemoryAllocation : public Allocation { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#endif // TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index bd87414ec3c8ac75b99e730fcac977a7afa08806..58bc164619c2c053b9492e9a0e5de2da30e199af 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ +#define TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ #include #include @@ -104,4 +104,4 @@ class ArenaPlanner : public MemoryPlanner { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 0a097d5a69a8bc15aa03502f7a2131fc36e36091..19829e4991651111e13fc1805f97daef8bc016a7 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -5,25 +5,25 @@ def tflite_copts(): copts = [ "-DFARMHASH_NO_CXX_STRING", ] + select({ - "//tensorflow:android_arm64": [ + str(Label("//tensorflow:android_arm64")): [ "-std=c++11", "-O3", ], - "//tensorflow:android_arm": [ + str(Label("//tensorflow:android_arm")): [ "-mfpu=neon", "-mfloat-abi=softfp", "-std=c++11", "-O3", ], - "//tensorflow:android_x86": [ + str(Label("//tensorflow:android_x86")): [ "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", ], - "//tensorflow:ios_x86_64": [ + str(Label("//tensorflow:ios_x86_64")): [ "-msse4.1", ], "//conditions:default": [], }) + select({ - "//tensorflow:with_default_optimizations": [], + str(Label("//tensorflow:with_default_optimizations")): [], "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], }) diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 3b43a1fd5d383b8b9eee1704b7a1b80b8d4059d4..0b48ef4741ac921e34dd56930783499c5040d581 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ #include @@ -88,7 +88,9 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; -typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; +typedef struct { + TfLiteFusedActivation activation; +} TfLiteFullyConnectedParams; typedef enum { kTfLiteLshProjectionUnknown = 0, @@ -96,9 +98,13 @@ typedef enum { kTfLiteLshProjectionDense = 2, } TfLiteLSHProjectionType; -typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams; +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; -typedef struct { float beta; } TfLiteSoftmaxParams; +typedef struct { + float beta; +} TfLiteSoftmaxParams; typedef struct { int axis; @@ -166,11 +172,6 @@ typedef struct { } TfLiteResizeBilinearParams; typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int before_padding[8]; - int after_padding[8]; - int num_dimensions; } TfLitePadParams; typedef struct { @@ -226,8 +227,16 @@ typedef struct { int num_squeeze_dims; } TfLiteSqueezeParams; +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; +} TfLiteStridedSliceParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index fca71165034a46b39803f4500af8dc5c6f4e8829..d6dfc20ae829b13e9cb45efcf9e14af5d4b69b48 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -26,8 +26,8 @@ limitations under the License. // TfLiteRegistration - the implementation of a conceptual operation. // // Some abstractions in this file are created and managed by Interpreter. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #include #include @@ -296,4 +296,4 @@ typedef struct { #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 362e5bee25e95e87fa22bb77904056e732c4e140..e1b7b3613a041287ff3cc4eeff8afd7cfcede174 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -22,7 +22,14 @@ cd "$SCRIPT_DIR/../../.." DOWNLOADS_DIR=tensorflow/contrib/lite/downloads BZL_FILE_PATH=tensorflow/workspace.bzl -EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +# Ensure it is being run from repo root +if [ ! -f $BZL_FILE_PATH ]; then + echo "Could not find ${BZL_FILE_PATH}": + echo "Likely you are not running this from the root directory of the repository."; + exit 1; +fi + +EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index d5715e4f90aead79a617fe4576bfe5100d5e121a..da193d2586e9123341b9a41be049ee2a4382017a 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ #include #include "tensorflow/contrib/lite/context.h" @@ -51,4 +51,4 @@ ErrorReporter* DefaultErrorReporter(); } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index a885a57b65c5c40ec13cc1c8893e02f4f75ed106..0ab7aa25d0b4e6d2c02e61ec1d82b85258b3dfbc 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -29,13 +29,6 @@ #include "ios_image_load.h" -#define LOG(x) std::cerr -#define CHECK(x) \ - if (!(x)) { \ - LOG(ERROR) << #x << "failed"; \ - exit(1); \ - } - NSString* RunInferenceOnImage(); @interface RunModelViewController () @@ -89,8 +82,8 @@ static void GetTopN(const float* prediction, const int prediction_size, const in NSString* FilePathForResourceName(NSString* name, NSString* extension) { NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] - << "' in bundle."; + NSLog(@"Couldn't find '%@.%@' in bundle.", name, extension); + exit(-1); } return file_path; } @@ -106,11 +99,12 @@ NSString* RunInferenceOnImage() { std::unique_ptr model( tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); if (!model) { - LOG(FATAL) << "Failed to mmap model " << [graph UTF8String]; + NSLog(@"Failed to mmap model %@.", graph); + exit(-1); } - LOG(INFO) << "Loaded model " << [graph UTF8String]; + NSLog(@"Loaded model %@.", graph); model->error_reporter(); - LOG(INFO) << "resolved reporter"; + NSLog(@"Resolved reporter."); #ifdef TFLITE_CUSTOM_OPS_HEADER tflite::MutableOpResolver resolver; @@ -122,7 +116,8 @@ NSString* RunInferenceOnImage() { std::unique_ptr interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); if (!interpreter) { - LOG(FATAL) << "Failed to construct interpreter"; + NSLog(@"Failed to construct interpreter."); + exit(-1); } if (num_threads != -1) { @@ -136,7 +131,8 @@ NSString* RunInferenceOnImage() { } if (interpreter->AllocateTensors() != kTfLiteOk) { - LOG(FATAL) << "Failed to allocate tensors!"; + NSLog(@"Failed to allocate tensors."); + exit(-1); } // Read the label list @@ -181,7 +177,8 @@ NSString* RunInferenceOnImage() { } if (interpreter->Invoke() != kTfLiteOk) { - LOG(FATAL) << "Failed to invoke!"; + NSLog(@"Failed to invoke!"); + exit(-1); } float* output = interpreter->typed_output_tensor(0); @@ -211,11 +208,9 @@ NSString* RunInferenceOnImage() { ss << "\n"; } - LOG(INFO) << "Predictions: " << ss.str(); - std::string predictions = ss.str(); NSString* result = @""; result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; - + NSLog(@"Predictions: %@", result); return result; } diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 4d2e1ce0bc751667393c4b38acc0517980c9f02a..d7f49ad8757e8899fe9c23b985edff6ba7f68750 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -148,14 +148,22 @@ void RunInference(Settings* s) { int wanted_width = dims->data[2]; int wanted_channels = dims->data[3]; - if (s->input_floating) { - downsize(interpreter->typed_tensor(input), in, image_height, - image_width, image_channels, wanted_height, wanted_width, - wanted_channels, s); - } else { - downsize(interpreter->typed_tensor(input), in, - image_height, image_width, image_channels, wanted_height, - wanted_width, wanted_channels, s); + switch (interpreter->tensor(input)->type) { + case kTfLiteFloat32: + s->input_floating = true; + downsize(interpreter->typed_tensor(input), in, + image_height, image_width, image_channels, + wanted_height, wanted_width, wanted_channels, s); + break; + case kTfLiteUInt8: + downsize(interpreter->typed_tensor(input), in, + image_height, image_width, image_channels, + wanted_height, wanted_width, wanted_channels, s); + break; + default: + LOG(FATAL) << "cannot handle input type " + << interpreter->tensor(input)->type << " yet"; + exit(-1); } struct timeval start_time, stop_time; @@ -177,13 +185,22 @@ void RunInference(Settings* s) { std::vector> top_results; - if (s->input_floating) { - get_top_n(interpreter->typed_output_tensor(0), output_size, - num_results, threshold, &top_results, s->input_floating); - } else { - get_top_n(interpreter->typed_output_tensor(0), + int output = interpreter->outputs()[0]; + switch (interpreter->tensor(output)->type) { + case kTfLiteFloat32: + get_top_n(interpreter->typed_output_tensor(0), output_size, num_results, threshold, &top_results, - s->input_floating); + true); + break; + case kTfLiteUInt8: + get_top_n(interpreter->typed_output_tensor(0), + output_size, num_results, threshold, &top_results, + false); + break; + default: + LOG(FATAL) << "cannot handle output type " + << interpreter->tensor(input)->type << " yet"; + exit(-1); } std::vector labels; @@ -203,13 +220,11 @@ void display_usage() { LOG(INFO) << "label_image\n" << "--accelerated, -a: [0|1], use Android NNAPI or note\n" << "--count, -c: loop interpreter->Invoke() for certain times\n" - << "--input_floating, -f: [0|1] type of input layer is floating " - "point numbers\n" << "--input_mean, -b: input mean\n" << "--input_std, -s: input standard deviation\n" << "--image, -i: image_name.bmp\n" << "--labels, -l: labels for the model\n" - << "--tflite_mode, -m: model_name.tflite\n" + << "--tflite_model, -m: model_name.tflite\n" << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "\n"; @@ -223,7 +238,6 @@ int Main(int argc, char** argv) { static struct option long_options[] = { {"accelerated", required_argument, 0, 'a'}, {"count", required_argument, 0, 'c'}, - {"input_floating", required_argument, 0, 'f'}, {"verbose", required_argument, 0, 'v'}, {"image", required_argument, 0, 'i'}, {"labels", required_argument, 0, 'l'}, @@ -254,11 +268,6 @@ int Main(int argc, char** argv) { s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) optarg, (char**)NULL, 10); break; - case 'f': - s.input_floating = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); - s.input_layer_type = "float"; - break; case 'i': s.input_bmp_name = optarg; break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h index ce98e06fc162a9588707eae701e2fcb8d648a4e4..4de32e33fb4ef2ab5d0e111886cdc737398147e9 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H #define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H -#include #include "tensorflow/contrib/lite/string.h" +namespace tflite { +namespace label_image { + struct Settings { bool verbose = false; bool accel = false; @@ -33,4 +35,7 @@ struct Settings { int number_of_threads = 4; }; +} // namespace label_image +} // namespace tflite + #endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/contrib/lite/examples/label_image/label_image.md index d6019d673f1b15429e69b57e8dc9eeaad2825bc3..9ce32cf101897f2d41cd14a485aeb432344928a0 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.md +++ b/tensorflow/contrib/lite/examples/label_image/label_image.md @@ -1,8 +1,12 @@ label_image for TensorFlow Lite inspired by TensorFlow's label_image. + +To build label_image for Android, run $TENSORFLOW_ROOT/configure +and set Android NDK or configure NDK setting in +$TENSORFLOW_ROOT/WORKSPACE first. To build it for android ARMv8: ``` -> bazel build --cxxopt=-std=c++11 \ +> bazel build --config monolithic --cxxopt=-std=c++11 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=arm64-v8a \ @@ -10,13 +14,13 @@ To build it for android ARMv8: ``` or ``` -> bazel build --config android_arm64 --cxxopt=-std=c++11 \ +> bazel build --config android_arm64 --config monolithic --cxxopt=-std=c++11 \ //tensorflow/contrib/lite/examples/label_image:label_image ``` To build it for android arm-v7a: ``` -> bazel build --cxxopt=-std=c++11 \ +> bazel build --config monolithic --cxxopt=-std=c++11 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=armeabi-v7a \ @@ -24,7 +28,7 @@ To build it for android arm-v7a: ``` or ``` -> bazel build --config android_arm --cxxopt=-std=c++11 \ +> bazel build --config android_arm --config monolithic --cxxopt=-std=c++11 \ //tensorflow/contrib/lite/examples/label_image:label_image ``` diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 5481aede605453958adb2c2e661c73130046d9f9..57690058c4630f75f8b23073f4ab44f27090c51b 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ +#define TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ #include @@ -50,4 +50,4 @@ class GraphInfo { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ +#endif // TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 5f5981e45a20a2c79ea1a2ba08345e831ce194da..69a597dc5a219b55eced6ec8da5b388caf372b8e 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -291,7 +291,6 @@ TfLiteStatus Interpreter::Invoke() { TfLiteStatus status = kTfLiteOk; if (nnapi_delegate_) { - TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); if (next_node_to_prepare_ == nodes_and_registration_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); return kTfLiteOk; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 38dd402e8a971fd0aab51e98610ad12131441862..4f732769f9f921a9debd5213547d2baccfa69426 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Main abstraction controlling the tflite interpreter. // See context.h for the API for defining operations (TfLiteRegistration). -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ #include #include @@ -363,4 +363,4 @@ class Interpreter { }; } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/kernels/Android.bp b/tensorflow/contrib/lite/kernels/Android.bp index de53078c8af2783cc876636ad350d0adb48fb6a9..a171eba48d2a5314b3c9b3c553f89725e79d83a8 100644 --- a/tensorflow/contrib/lite/kernels/Android.bp +++ b/tensorflow/contrib/lite/kernels/Android.bp @@ -23,6 +23,9 @@ cc_library_static { "internal/reference/portable_tensor_utils.cc", "internal/optimized/neon_tensor_utils.cc", ], + cflags: [ + "-Wno-extern-c-compat", + ] } cc_library_static { @@ -58,9 +61,11 @@ cc_library_static { "space_to_batch_nd.cc", "space_to_depth.cc", "squeeze.cc", + "strided_slice.cc", "sub.cc", "svdf.cc", "transpose.cc", + "unidirectional_sequence_lstm.cc", "unidirectional_sequence_rnn.cc", "internal/tensor_utils.cc", "internal/quantization_util.cc", @@ -76,6 +81,7 @@ cc_library_static { cflags: [ "-DNAMESPACE_FOR_HASH_FUNCTIONS=farmhash", "-Wno-array-bounds", + "-Wno-extern-c-compat", "-Wno-invalid-partial-specialization", "-Wno-missing-field-initializers", "-Wno-sign-compare", diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 7e9644f36c71ff7e03a04dd01743be811632f077..4195e7553c48028d56e80db0d204ef5656be874d 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -50,7 +50,7 @@ cc_library( deps = [ ":op_macros", "//tensorflow/contrib/lite:context", - "@gemmlowp//:gemmlowp", + "@gemmlowp", ], ) @@ -103,9 +103,11 @@ cc_library( "space_to_batch_nd.cc", "space_to_depth.cc", "squeeze.cc", + "strided_slice.cc", "sub.cc", "svdf.cc", "transpose.cc", + "unidirectional_sequence_lstm.cc", "unidirectional_sequence_rnn.cc", ], hdrs = [ @@ -249,6 +251,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "unidirectional_sequence_lstm_test", + size = "small", + srcs = ["unidirectional_sequence_lstm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "unidirectional_sequence_rnn_test", size = "small", @@ -505,6 +519,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "strided_slice_test", + size = "small", + srcs = ["strided_slice_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h index cfb3369e991a474315424423fe655ba214edabbc..41ec3cca33ae1c6bb3f7c43dd1923f104c2ab6a2 100644 --- a/tensorflow/contrib/lite/kernels/activation_functor.h +++ b/tensorflow/contrib/lite/kernels/activation_functor.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ #include #include @@ -55,4 +55,4 @@ class ActivationFunctor { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index c75c04baeac2ce53c6261d677dca8d72fafa0da5..37f499a4d09a38765aa4b8db8aa91b708edd7823 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -322,22 +322,23 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); - const float* filter_data; - if (data->need_hwcn_weights) { - filter_data = GetTensorData(hwcn_weights); - } else { - filter_data = GetTensorData(filter); - } - if (kernel_type == kReference) { - reference_ops::Conv( - GetTensorData(input), GetTensorDims(input), filter_data, - GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, data->padding.width, - data->padding.height, output_activation_min, output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col)); + reference_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); } else { + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + multithreaded_ops::Conv( GetTensorData(input), GetTensorDims(input), filter_data, GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index f8df797daf7338e33b16508c21fc61cd9836db1e..0e4187d1eac64636a2e2b25e9a1cc45c3a4da557 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); // Check that input and output types match. TF_LITE_ENSURE_EQ(context, input->type, output->type); - // TODO(mgubin): only 1D positions are currently supported. - TF_LITE_ENSURE_EQ(context, NumDimensions(positions), 1); + // TODO(mgubin): only 0D or 1D positions are currently supported. + TF_LITE_ENSURE(context, NumDimensions(positions) <= 1); // TODO(mgubin): Only default axis == 0 is supported. + TF_LITE_ENSURE_EQ(context, params->axis, 0); // Check conditions for different types. switch (input->type) { case kTfLiteFloat32: @@ -64,7 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } const int num_dimensions = NumDimensions(input) + NumDimensions(positions) - 1; - TF_LITE_ENSURE(context, params->axis < num_dimensions); + TF_LITE_ENSURE(context, params->axis <= num_dimensions); TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); int output_index = 0; for (int i = 0; i < params->axis; ++i) { diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index 6343d3b4ef20ae3e030396ec1b6adbcf83a3e45f..658d977b8dc7fffcdde69d74ba2564dfa1b5709e 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -48,8 +48,8 @@ class GatherOpModel : public SingleOpModel { PopulateStringTensor(input_, data); } - void SetPositions(std::initializer_list data) { - PopulateTensor(positions_, data); + void SetPositions(std::initializer_list data) { + PopulateTensor(positions_, data); } std::vector GetOutputFloat() { return ExtractVector(output_); } @@ -76,6 +76,29 @@ TEST(GatherOpTest, Shuffle) { ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2}))); } +TEST(GatherOpTest, Test0DIndex) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); + EXPECT_THAT(m.GetOutputShape(), + ElementsAreArray({2})); +} + +TEST(GatherOpTest, Test0DIndexWith0DResult) { + // 0D tensor is special case in current TFLite. Test it once to make sure + // existing workarounds are fine with it. + GatherOpModel m({3}, TensorType_FLOAT32, {}); + m.SetInputFloat({1.0, 2.0, 3.0}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({2.0}))); + EXPECT_TRUE(m.GetOutputShape().empty()); +} + TEST(FloatGatherOpTest, Duplicate) { GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index b531959ffb143c774ee715743480b03ebfbdc114..466781cbcecc7fb851d9078c450cc6c12364d2bb 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ #include "public/gemmlowp.h" #include "tensorflow/contrib/lite/context.h" @@ -51,4 +51,4 @@ void SetMaxNumThreads(TfLiteContext* context, int num_threads); } // namespace gemm_support } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index a3ecb2ebf6a889729954d1e447997c510e8ff6d4..21118fc96d804654a33d5c693d496b05e2dc59d2 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -145,7 +145,7 @@ cc_library( ":types", ":round", "//third_party/eigen3", - "@gemmlowp//:gemmlowp", + "@gemmlowp", "//tensorflow/contrib/lite:builtin_op_data", ] + select({ ":haswell": tflite_deps_intel, @@ -223,7 +223,7 @@ cc_library( ":round", ":types", "//third_party/eigen3", - "@gemmlowp//:gemmlowp", + "@gemmlowp", "//tensorflow/contrib/lite:builtin_op_data", ] + select({ ":haswell": tflite_deps_intel, diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 1e88b4774ce2d22d7319751f80ffee1970996847..4d7a3a4e98497653071f9bbee464bc05a8e821e5 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK @@ -106,4 +106,4 @@ inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h index 796a03566a4bf971294dd2375f590dfd20d600f7..1d963afb7e1ce414f251f090208923ca0c68cee1 100644 --- a/tensorflow/contrib/lite/kernels/internal/compatibility.h +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ #include #include @@ -75,4 +75,4 @@ using uint16 = std::uint16_t; using int32 = std::int32_t; using uint32 = std::uint32_t; -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index da34c8aef94b1c69e661bd33fcb518e73034c4bd..81796e295d9c7ae1f04163467c8b2af851b632c2 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ #include "public/gemmlowp.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" @@ -1057,4 +1057,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, } // namespace optimized_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index 051ed2a2c44a04f0473dfd26637e53865a5a51ac..f993fd6a00f054c670b247e886a1d9d2a34643e7 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" @@ -1913,4 +1913,4 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, } // namespace optimized_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h index a67b35fd65891b1d0d7323d926b5a68bdf0a05b3..86bd8eced5cddd85a7281e99edbdee8cb6ab85cb 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -16,8 +16,8 @@ limitations under the License. // Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h. // TODO(petewarden) - move this to a common location in Eigen itself. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ #define EIGEN_USE_CUSTOM_THREAD_POOL #define EIGEN_USE_THREADS @@ -228,4 +228,4 @@ EIGEN_DEVICE_FUNC // clang-format on -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h index 1d7e254b93dbdf4cb6b998f26dc198ee9658b9d1..655c8a3f022fe83c946d8028491ad750bb83b73a 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ #define EIGEN_USE_CUSTOM_THREAD_POOL #define EIGEN_USE_THREADS @@ -140,4 +140,4 @@ limitations under the License. #include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" #include "Eigen/src/Core/util/ReenableStupidWarnings.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h index fc33f2d5df809877cb3d30c1050bc466ed76b9f6..3141cd3002597347108f7a5b310e3ac3674fe2c5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -19,8 +19,8 @@ limitations under the License. // clang-format off -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ #include "Eigen/Core" @@ -166,4 +166,4 @@ typedef unsigned __int64 uint64_t; #include "Eigen/src/Core/util/ReenableStupidWarnings.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index b3615f4658a1a70284cc9d386a868a87aa09819b..0bfb4e9b1f8ee4167cfb629645a38538be1d73d4 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV #include #include @@ -192,4 +192,4 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, } // namespace multithreaded_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 3a4af87304eaf33489b38bd9b15ad9789e091d24..b7e317dc60e2c68e9e993ff45c9090a01bd13b94 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. @@ -110,4 +110,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index ded5ae8ff50cfc5337a5ea5f6e4880b701246aa6..f65ca6adad71b078a8c71880a1d620545206457d 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ #include #include @@ -1538,9 +1538,10 @@ void Add(const int32* input1_data, const Dims<4>& input1_dims, // reference_ops.h. Once an optimized version is implemented and NdArrayDesc // is no longer referenced in this file, move NdArrayDesc from types.h to // reference_ops.h. -template +template void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); @@ -1563,15 +1564,30 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastAdd(int left_shift, const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, int32 input1_multiplier, int input1_shift, @@ -1772,9 +1788,10 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims, // reference_ops.h. Once an optimized version is implemented and NdArrayDesc // is no longer referenced in this file, move NdArrayDesc from types.h to // reference_ops.h. -template +template void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastMul"); @@ -1797,15 +1814,30 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] * - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastMul(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, const uint8* input2_data, const Dims<4>& input2_dims, int32 input2_offset, @@ -3805,4 +3837,4 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, #pragma GCC diagnostic pop #endif -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h index 8e0f234545e43dd8b2412e065aaecad8325a1182..9aabee5000c29ed97fcf7e874d661e72fd768f84 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" @@ -112,4 +112,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, } // end namespace reference_ops } // end namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h index 8a80558b32f2858778460956cd9f57617674e21e..e9b6baeaee87d22aef238410bc9f447509a81c47 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ #include @@ -135,4 +135,4 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, } // end namespace reference_ops } // end namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index 7f90d731b8454a020ab273e6b5591ed90aab14c7..afc3e26e7988a369fb777ae99c08c4e98f26ebb8 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ // TDOD(ghodrat): Remove this header file and the dependency to internal data // structure. @@ -186,4 +186,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 7f1f3143e8e2fa1e4a7c2a1902920e9e86ad7f68..5ad1178f8c473261a8024da1b91533095e82a2d4 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ #include #include @@ -889,10 +889,11 @@ inline void Add(int left_shift, const uint8* input1_data, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); NdArrayDesc<4> desc1; @@ -914,15 +915,30 @@ void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastAdd(int left_shift, const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, int32 input1_multiplier, int input1_shift, @@ -1053,10 +1069,11 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastMul"); NdArrayDesc<4> desc1; @@ -1078,15 +1095,30 @@ void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] * - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastMul(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, const uint8* input2_data, const Dims<4>& input2_dims, int32 input2_offset, @@ -2330,6 +2362,18 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, } } +inline bool LoopCondition(int index, int stop, int stride) { + return stride > 0 ? index < stop : index > stop; +} + +inline int StartIndex(int start, int stride, int dim, bool masked) { + return masked ? (stride > 0 ? 0 : dim - 1) : start; +} + +inline int StopIndex(int stop, int stride, int dim, bool masked) { + return masked ? (stride > 0 ? dim : -1) : stop; +} + template inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, int begin_mask, int end_mask, @@ -2337,20 +2381,35 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, const std::vector& stops, const std::vector& strides, T* output_data, const Dims<4>& output_dims) { - const int start_b = (begin_mask & 8) ? 0 : starts[3]; - const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; - const int start_h = (begin_mask & 4) ? 0 : starts[2]; - const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; - const int start_w = (begin_mask & 2) ? 0 : starts[1]; - const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; - const int start_d = (begin_mask & 1) ? 0 : starts[0]; - const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + TFLITE_DCHECK_EQ(starts.size(), 4); + TFLITE_DCHECK_EQ(stops.size(), 4); + TFLITE_DCHECK_EQ(strides.size(), 4); + const int start_b = + StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8); + const int stop_b = + StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8); + const int start_h = + StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4); + const int stop_h = + StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4); + const int start_w = + StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2); + const int stop_w = + StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2); + const int start_d = + StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1); + const int stop_d = + StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1); T* out_ptr = output_data; - for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { - for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { - for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { - for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]); + in_b += strides[3]) { + for (int in_h = start_h; LoopCondition(in_h, stop_h, strides[2]); + in_h += strides[2]) { + for (int in_w = start_w; LoopCondition(in_w, stop_w, strides[1]); + in_w += strides[1]) { + for (int in_d = start_d; LoopCondition(in_d, stop_d, strides[0]); + in_d += strides[0]) { *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; } } @@ -2628,4 +2687,4 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output, } // namespace reference_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h index 38525b0e208b852343849096ac68cbfc9ef3e389..f299d0bd8733dc603c4950091c8ac3d7890548a7 100644 --- a/tensorflow/contrib/lite/kernels/internal/round.h +++ b/tensorflow/contrib/lite/kernels/internal/round.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ #include @@ -36,4 +36,4 @@ inline T TfLiteRound(const T x) { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index 1961e1a2d5ecd4fd20c6f442b79dc88ed28062fe..dfe76c2afd40c692063710a4d98464b55e40feb9 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ #include #include "tensorflow/contrib/lite/context.h" @@ -83,4 +83,4 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index e7e2994397650004c7ba442fa1803290e6b12302..40d144979b2f965725db86ff311e90f39438802f 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ #include "tensorflow/contrib/lite/builtin_op_data.h" @@ -113,4 +113,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 5989ac8fcdec101c14dd7b04d89fe8c7bfce0a10..afe131b06ec41201395e80aa5415fd7db990f8d4 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" @@ -134,4 +134,4 @@ bool IsPackedWithoutStrides(const Dims& dims) { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 25556ae4567aca45b3bfe4ba02b1cb58331d239d..1cf30ecff9760d218d279cc6c7132589e11cc15c 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" @@ -62,4 +62,4 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation, } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h index 63670efcb1e6349317aa5c75756707fb7a7fa2aa..7568eaa88edfa3260964e16f03299aecb97da6be 100644 --- a/tensorflow/contrib/lite/kernels/op_macros.h +++ b/tensorflow/contrib/lite/kernels/op_macros.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ #include @@ -31,4 +31,4 @@ limitations under the License. if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \ } while (0) -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 1a0d9d1505d41fb7948863f9da9e2a4f1b61e4f9..569bf0fe8fc9964a1299911d248d53862c99cbdf 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -33,65 +33,93 @@ enum KernelType { kGenericOptimized, }; -// TODO(nupurgarg): Padding represented as a tensor is ignored. Only use the -// `left_padding` and `right_padding` specified in `params`. struct PadContext { PadContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + paddings = GetInput(context, node, 1); output = GetOutput(context, node, 0); + dims = NumDimensions(input); } - TfLitePadParams* params; TfLiteTensor* input; + TfLiteTensor* paddings; TfLiteTensor* output; + int dims; }; +// Resizes output array based on the input size and padding size. This function +// is callable from both Prepare() and Eval() as long as the caller ensures the +// paddings data is present. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + PadContext* op_context) { + // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, op_context->dims, 4); + + // Ensures the paddings array is dims x 2. + TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0), + op_context->dims); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2); + + // Determines the size of the output tensor. + const TfLiteIntArray* input_size = op_context->input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context->dims); + const int32* paddings_data = GetTensorData(op_context->paddings); + + for (int idx = 0; idx < op_context->dims; ++idx) { + int before_padding = *paddings_data++; + int after_padding = *paddings_data++; + + TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0), + "Pad value has to be greater than equal to 0."); + + output_size->data[idx] = + (input_size->data[idx] + before_padding + after_padding); + } + + return context->ResizeTensor(context, op_context->output, output_size); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - // Determines size of output tensor. PadContext op_context(context, node); - int dims = NumDimensions(op_context.input); - TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions); TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. - TF_LITE_ENSURE_EQ(context, dims, 4); - - const TfLiteIntArray* input_size = op_context.input->dims; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims); - for (int idx = 0; idx < dims; ++idx) { - TF_LITE_ENSURE_MSG(context, - (op_context.params->before_padding[idx] >= 0 && - op_context.params->after_padding[idx] >= 0), - "Pad value has to be greater than equal to 0."); - output_size->data[idx] = - (input_size->data[idx] + op_context.params->before_padding[idx] + - op_context.params->after_padding[idx]); + // TODO(nupurgarg): Create wrapper functions for dynamic tensor logic. + // Exit early if paddings is a non-const tensor. Set output tensor to + // dynamic so output size can be determined in Eval. + if (op_context.paddings->allocation_type != kTfLiteMmapRo) { + op_context.output->allocation_type = kTfLiteDynamic; + return kTfLiteOk; } - - return context->ResizeTensor(context, op_context.output, output_size); + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); - std::vector before_padding( - op_context.params->before_padding, - op_context.params->before_padding + op_context.params->num_dimensions); - std::vector after_padding( - op_context.params->after_padding, - op_context.params->after_padding + op_context.params->num_dimensions); - - // TODO(nupurgarg): Change TOCO's implementation to use padding arrays - // in forward order (depth, width, height, batch). - // Converts from int[] = {depth, width, height, batch} to int[] = {batch, - // height, width, depth} to match TOCO's implementation of pad in - // referenced_ops.h and optimized_ops.h. - std::reverse(before_padding.begin(), before_padding.end()); - std::reverse(after_padding.begin(), after_padding.end()); + // Resize the output tensor if the output tensor is dynamic. + if (op_context.output->allocation_type == kTfLiteDynamic) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } + + // TODO(nupurgarg): Change kernel implementation to take in int* instead of + // vector to remove malloc from Eval(). + // Create before and after padding arrays that are accepted by the kernel. + std::vector before_padding; + std::vector after_padding; + const int32* paddings_data = GetTensorData(op_context.paddings); + + // TODO(nupurgarg): Change kernel implementation to use padding arrays in + // forward order (depth, width, height, batch). + // Build paddings in order of int[] = {batch, height, width, depth} to match + // kernel implementation of Pad in referenced_ops.h and optimized_ops.h. + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + before_padding.push_back(paddings_data[idx * 2]); + after_padding.push_back(paddings_data[idx * 2 + 1]); + } #define TF_LITE_PAD(type, scalar) \ type::Pad(GetTensorData(op_context.input), \ diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc index f3ea9417df0e61dcff7a877726ab91c9b22691ba..28834ad0719291b2e868bca2d86a6685e6eb9962 100644 --- a/tensorflow/contrib/lite/kernels/pad_test.cc +++ b/tensorflow/contrib/lite/kernels/pad_test.cc @@ -25,52 +25,87 @@ using ::testing::ElementsAreArray; class PadOpModel : public SingleOpModel { public: - PadOpModel(std::initializer_list input_shape, - std::initializer_list before_padding, - std::initializer_list after_padding) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_PAD, BuiltinOptions_PadOptions, - CreatePadOptions(builder_, builder_.CreateVector(before_padding), - builder_.CreateVector(after_padding)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetPaddings(std::initializer_list paddings) { + PopulateTensor(paddings_, paddings); + } + std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; int output_; + int paddings_; +}; + +// Tests case where paddings is a const tensor. +// +// Example usage is as follows: +// PadOpDynamicModel m(input_shape, paddings_shape, paddings_data); +// m.SetInput(input_data); +// m.Invoke(); +class PadOpConstModel : public PadOpModel { + public: + PadOpConstModel(std::initializer_list input_shape, + std::initializer_list paddings_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Test case where paddings is a non-const tensor. +// +// Example usage is as follows: +// PadOpDynamicModel m(input_shape, paddings_shape); +// m.SetInput(input_data); +// m.SetPaddings(paddings_data); +// m.Invoke(); +class PadOpDynamicModel : public PadOpModel { + public: + PadOpDynamicModel(std::initializer_list input_shape, + std::initializer_list paddings_shape) { + input_ = AddInput(TensorType_FLOAT32); + paddings_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_).Union()); + BuildInterpreter({input_shape, paddings_shape}); + } }; TEST(PadOpTest, TooManyDimensions) { EXPECT_DEATH( - PadOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, - {1, 2, 3, 4, 5, 6, 7, 8, 9}), + PadOpConstModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 2}, + {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}), "dims != 4"); } -// TODO(nupurgarg): Test case where before padding and after padding arrays -// don't contain the same number of dimensions. TEST(PadOpTest, UnequalDimensions) { - EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {1, 2, 3}, {1, 2, 3}), - "dims != op_context.params->num_dimensions"); + EXPECT_DEATH(PadOpConstModel({1, 1, 2, 1}, {3, 2}, {1, 1, 2, 2, 3, 3}), + "3 != 4"); } TEST(PadOpTest, InvalidPadValue) { - EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {0, 1, 2, 0}, {0, -1, -1, 0}), - "Pad value has to be greater than equal to 0."); + EXPECT_DEATH( + PadOpConstModel({1, 1, 2, 1}, {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}), + "Pad value has to be greater than equal to 0."); } -TEST(PadOpTest, SimpleTest) { - PadOpModel m({1, 2, 2, 1}, {0, 1, 1, 0}, {0, 1, 1, 0}); +TEST(PadOpTest, SimpleConstTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). + PadOpConstModel m({1, 2, 2, 1}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}); m.SetInput({1, 2, 3, 4}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, @@ -78,10 +113,30 @@ TEST(PadOpTest, SimpleTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST(PadOpTest, AdvancedTest) { - // The padding is input in the order of batch, height, width, depth. - PadOpModel m({1, 2, 3, 1}, {0, 0, 1, 0}, {0, 2, 3, 0}); +TEST(PadOpTest, SimpleDynamicTest) { + PadOpDynamicModel m({1, 2, 2, 1}, {4, 2}); + m.SetInput({1, 2, 3, 4}); + m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadOpTest, AdvancedConstTest) { + PadOpConstModel m({1, 2, 3, 1}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +TEST(PadOpTest, AdvancedDynamicTest) { + PadOpDynamicModel m({1, 2, 3, 1}, {4, 2}); m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h index 3a60274524c468ef29e522de5569e0d8354974c2..40b8476b3779c66e31a04856bce8aebd378f1e5f 100644 --- a/tensorflow/contrib/lite/kernels/padding.h +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ namespace tflite { @@ -25,4 +25,4 @@ inline int ComputePadding(int stride, int in_size, int filter_size, } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 45ad5f18903927ff8f2743e96c167cfcb11bdcca..f605deaa5b4a3a8572c4be16cb1d301dbc49e5ba 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -48,6 +48,7 @@ TfLiteRegistration* Register_MUL(); TfLiteRegistration* Register_L2_NORMALIZATION(); TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_PAD(); TfLiteRegistration* Register_RESHAPE(); TfLiteRegistration* Register_RESIZE_BILINEAR(); @@ -57,6 +58,7 @@ TfLiteRegistration* Register_GATHER(); TfLiteRegistration* Register_TRANSPOSE(); TfLiteRegistration* Register_MEAN(); TfLiteRegistration* Register_SQUEEZE(); +TfLiteRegistration* Register_STRIDED_SLICE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -89,6 +91,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + Register_UNIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_PAD, Register_PAD()); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); @@ -100,6 +104,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DIV, Register_DIV()); AddBuiltin(BuiltinOperator_SUB, Register_SUB()); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); + AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index 28f5e0fcc80a14cf9fb6fb19b795d0c0d55e0df9..b9cff0ae21086b44e0c920095d5f6c9668346f38 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ #include #include "tensorflow/contrib/lite/context.h" @@ -47,4 +47,4 @@ class BuiltinOpResolver : public OpResolver { } // namespace ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index 1613c9a89faa3579b913408cc09cdad7f942cb99..9a419af0238e1a25e4b9e81f109b54de6b49097b 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -33,49 +33,53 @@ enum KernelType { }; constexpr int kInputTensor = 0; +constexpr int kSizeTensor = 1; constexpr int kOutputTensor = 0; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = - reinterpret_cast(node->builtin_data); - - TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* size = GetInput(context, node, kSizeTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // TODO(ahentz): Our current implementations rely on the inputs being 4D. TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); // TODO(ahentz): Our current implementations only support float32. - TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); - TF_LITE_ENSURE_EQ(context, input->type, output->type); - - TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); - output_size->data[0] = input->dims->data[0]; - output_size->data[1] = params->new_height; - output_size->data[2] = params->new_width; - output_size->data[3] = input->dims->data[3]; - - return context->ResizeTensor(context, output, output_size); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); + // ResizeBilinear creates a float tensor even when the input is made of + // integers. + output->type = kTfLiteFloat32; + + // TODO(ahentz): if the input is constant, we can allocate here. + output->allocation_type = kTfLiteDynamic; + return kTfLiteOk; } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = - reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* size = GetInput(context, node, kSizeTensor); - // We have to fake a tensor here, to satisfy ResizeBilinear(). - int32 output_size_data[2] = {params->new_height, params->new_width}; + // TODO(ahentz): we only need to do this here if it wasn't done in Eval(). + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + const int32* size_data = GetTensorData(size); + output_size->data[1] = size_data[0]; + output_size->data[2] = size_data[1]; + output_size->data[3] = input->dims->data[3]; + context->ResizeTensor(context, output, output_size); + TfLiteTensorRealloc(output->bytes, output); if (output->type == kTfLiteFloat32) { #define TF_LITE_RESIZE_BILINEAR(type) \ type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - output_size_data, GetTensorDims({1, 1, 1, 2}), \ + GetTensorData(size), GetTensorDims(size), \ GetTensorData(output), GetTensorDims(output)) if (kernel_type == kReference) { diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 314a71e210d9b5ea75bb137ef228273ef48f28b5..2b1aaf654f87f435ec464b2cc1a63c77ba86ae5b 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -25,47 +25,52 @@ using ::testing::ElementsAreArray; class ResizeBilinearOpModel : public SingleOpModel { public: - ResizeBilinearOpModel(std::initializer_list input_shape, int new_height, - int new_width) { + ResizeBilinearOpModel(std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); + size_ = AddInput(TensorType_INT32); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, - CreateResizeBilinearOptions(builder_, new_height, new_width).Union()); - BuildInterpreter({input_shape}); + SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, + BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}}); } void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } std::vector GetOutput() { return ExtractVector(output_); } private: int input_; + int size_; int output_; }; TEST(ResizeBilinearOpTest, HorizontalResize) { - ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3); + ResizeBilinearOpModel m({1, 1, 2, 1}); m.SetInput({3, 6}); + m.SetSize({1, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { - ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1); + ResizeBilinearOpModel m({1, 2, 1, 1}); m.SetInput({3, 9}); + m.SetSize({3, 1}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3); + ResizeBilinearOpModel m({1, 2, 2, 1}); m.SetInput({ 3, 6, // 9, 12 // }); + m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 3, 5, 6, // @@ -75,13 +80,14 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) { } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { - ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3); + ResizeBilinearOpModel m({2, 2, 2, 1}); m.SetInput({ 3, 6, // 9, 12, // 4, 10, // 10, 16 // }); + m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 3, 5, 6, // @@ -94,11 +100,12 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3); + ResizeBilinearOpModel m({1, 2, 2, 2}); m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); + m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 3, 4, 5, 8, 6, 10, // diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc index 409227b626afdc8cbed66a27e300b320b59023f2..a8aab88357cacbb72784a4bc6e860aeb47783eb3 100644 --- a/tensorflow/contrib/lite/kernels/squeeze_test.cc +++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc @@ -22,6 +22,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::IsEmpty; class BaseSqueezeOpModel : public SingleOpModel { public: @@ -103,6 +104,16 @@ TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) { 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); } +TEST(FloatSqueezeOpTest, SqueezeAllDims) { + std::initializer_list data = {3.85}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 1, 1, 1, 1, 1, 1}}, + {TensorType_FLOAT32, {1}}, {}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.85})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..91ba4a9b7851c35a5138f4ccea307c810a4731a1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -0,0 +1,256 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace strided_slice { + +enum KernelType { + kReference, + // TODO(soroosh): add kGenericOptimized +}; + +constexpr int kInputTensor = 0; +constexpr int kBeginTensor = 1; +constexpr int kEndTensor = 2; +constexpr int kStridesTensor = 3; +constexpr int kOutputTensor = 0; + +struct StridedSliceContext { + StridedSliceContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, kInputTensor); + begin = GetInput(context, node, kBeginTensor); + end = GetInput(context, node, kEndTensor); + strides = GetInput(context, node, kStridesTensor); + output = GetOutput(context, node, kOutputTensor); + dims = NumDimensions(input); + } + TfLiteStridedSliceParams* params; + TfLiteTensor* input; + TfLiteTensor* begin; + TfLiteTensor* end; + TfLiteTensor* strides; + TfLiteTensor* output; + int dims; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + StridedSliceContext op_context(context, node); + + // Ensure validity of input tensor and its dimension + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + // Only INT32 begin/end/strides are supported + // TODO(soroosh) add support for INT64 + TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, + "StridedSlice op only supports 1D-4D input arrays."); + + // TODO(soroosh): add the following missing functionalities + TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, + "ellipsis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, + "new_axis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0, + "shrink_axis_mask is not implemented yet."); + + // TODO(soroosh): optimize for constant tensors to do allocation in Prepare + op_context.output->allocation_type = kTfLiteDynamic; + return kTfLiteOk; +} // namespace strided_slice + +// TODO(soroosh): consolidate with BytesRequired in interpreter.h +TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type, + const int* dims, int dims_size, size_t* bytes) { + // TODO(aselle): Check for overflow here using overflow.h in TensorFlow + // MultiplyWithoutOverflow. + TF_LITE_ENSURE(context, bytes != nullptr); + size_t count = 1; + for (int k = 0; k < dims_size; k++) count *= dims[k]; + switch (type) { + case kTfLiteFloat32: + *bytes = sizeof(float) * count; + break; + case kTfLiteInt32: + *bytes = sizeof(int32_t) * count; + break; + case kTfLiteUInt8: + *bytes = sizeof(uint8_t) * count; + break; + case kTfLiteInt64: + *bytes = sizeof(int64_t) * count; + break; + default: + return kTfLiteError; + } + return kTfLiteOk; +} + +// Reverse order of bits in the mask to match the expected order in kernel +inline int ReverseMaskBits(int mask, int num_dimensions) { + int out = 0; + for (int dim = 0; dim < num_dimensions; dim++) { + out <<= 1; + out += (mask & 1); + mask >>= 1; + } + return out; +} + +// This Op only supports 1-4D cases and since we use the reference 4D +// implementation, the 1-3D tensors are mapped to 4D. +const int kMaxDim = 4; + +inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) { + return (divisor + (dividend % divisor)) % divisor; +} + +inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { + return pos_stride + ? (index >= dim ? dim + : PositiveRemainder( + std::min(std::max(index, -dim), dim), dim)) + : (index < -dim + ? -1 + : PositiveRemainder( + std::min(std::max(index, -dim), dim - 1), dim)); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + StridedSliceContext op_context(context, node); + + std::vector starts; + std::vector stops; + std::vector strides; + + // Determine size of output tensor and map indices + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims); + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + int dim = op_context.input->dims->data[idx]; + int32_t stride = GetTensorData(op_context.strides)[idx]; + TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); + bool pos_stride = stride > 0; + + int32_t begin = + op_context.params->begin_mask & (1 << idx) + ? pos_stride ? 0 : dim - 1 + : ClampedIndex(GetTensorData(op_context.begin)[idx], dim, + pos_stride); + int32_t end = + op_context.params->end_mask & (1 << idx) + ? pos_stride ? dim : -1 + : ClampedIndex(GetTensorData(op_context.end)[idx], dim, + pos_stride); + + // This is valid for both positive and negative strides + output_shape->data[idx] = ceil((end - begin) / static_cast(stride)); + output_shape->data[idx] = + output_shape->data[idx] < 0 ? 0 : output_shape->data[idx]; + starts.emplace_back(begin); + stops.emplace_back(end); + strides.emplace_back(stride); + } + + for (int i = op_context.dims; i < kMaxDim; i++) { + starts.emplace_back(0); + stops.emplace_back(1); + strides.emplace_back(1); + } + + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, op_context.output, output_shape)); + + size_t required_bytes; + TF_LITE_ENSURE_OK( + context, + BytesRequired(context, op_context.output->type, output_shape->data, + output_shape->size, &required_bytes)); + TfLiteTensorRealloc(required_bytes, op_context.output); + + op_context.params->begin_mask = + ReverseMaskBits(op_context.params->begin_mask, op_context.dims); + op_context.params->end_mask = + ReverseMaskBits(op_context.params->end_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), op_context.params->begin_mask, \ + op_context.params->end_mask, starts, stops, strides, \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, float); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported " + "by StridedSlice."); + return kTfLiteError; + } +#undef TF_LITE_STRIDED_SLICE + return kTfLiteOk; +} + +} // namespace strided_slice + +TfLiteRegistration* Register_STRIDED_SLICE_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, strided_slice::Prepare, + strided_slice::Eval}; + return &r; +} + +// TODO(soroosh): add optimized +TfLiteRegistration* Register_STRIDED_SLICE() { + return Register_STRIDED_SLICE_REF(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd4a364682c0e66b2ceec92c0b34461945caf779 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -0,0 +1,375 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class StridedSliceOpModel : public SingleOpModel { + public: + StridedSliceOpModel(std::initializer_list input_shape, + std::initializer_list begin_shape, + std::initializer_list end_shape, + std::initializer_list strides_shape, int begin_mask, + int end_mask, int ellipsis_mask, int new_axis_mask, + int shrink_axis_mask) { + input_ = AddInput(TensorType_FLOAT32); + begin_ = AddInput(TensorType_INT32); + end_ = AddInput(TensorType_INT32); + strides_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, + CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, + new_axis_mask, shrink_axis_mask) + .Union()); + BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); + } + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); + } + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int end_; + int strides_; + int output_; +}; + +TEST(StridedSliceOpTest, UnsupportedInputSize) { + EXPECT_DEATH( + StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), + "StridedSlice op only supports 1D-4D input arrays."); +} + +TEST(StridedSliceOpTest, UnssupportedArgs) { + EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), + "ellipsis_mask is not implemented yet."); + EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), + "new_axis_mask is not implemented yet."); + EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1), + "shrink_axis_mask is not implemented yet."); +} + +TEST(StridedSliceOpTest, In1D) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); +} + +TEST(StridedSliceOpTest, In1D_EmptyOutput) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({10}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-5}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In1D_NegativeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({-2}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({5}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} + +TEST(StridedSliceOpTest, In1D_BeginMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-2}); + m.SetEnd({-3}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({5}); + m.SetEnd({2}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4})); +} + +TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({2}); + m.SetEnd({-4}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({-5}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1})); +} + +TEST(StridedSliceOpTest, In1D_EndMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} +TEST(StridedSliceOpTest, In1D_NegStride) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3}); + m.SetBegin({-1}); + m.SetEnd({-4}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1})); +} + +TEST(StridedSliceOpTest, In1D_EvenLenStride2) { + StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2}); + m.SetBegin({0}); + m.SetEnd({2}); + m.SetStrides({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} +TEST(StridedSliceOpTest, In1D_OddLenStride2) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3}); + m.SetBegin({0}); + m.SetEnd({3}); + m.SetStrides({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); +} + +TEST(StridedSliceOpTest, In2D_Identity) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} +TEST(StridedSliceOpTest, In2D) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); +} + +TEST(StridedSliceOpTest, In2D_Stride2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); +} + +TEST(StridedSliceOpTest, In2D_NegStride) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -1}); + m.SetEnd({2, -4}); + m.SetStrides({2, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); +} + +TEST(StridedSliceOpTest, In2D_BeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); +} + +TEST(StridedSliceOpTest, In2D_EndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6})); +} + +TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -2}); + m.SetEnd({2, -4}); + m.SetStrides({1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); +} +TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -2}); + m.SetEnd({2, -3}); + m.SetStrides({1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4})); +} + +TEST(StridedSliceOpTest, In3D_Identity) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); +} + +TEST(StridedSliceOpTest, In3D_NegStride) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({-1, -1, -1}); + m.SetEnd({-3, -4, -3}); + m.SetStrides({-1, -1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})); +} +TEST(StridedSliceOpTest, In3D_Strided2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({2, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index b69f2b3e4bc66c94fdfc7ed4c244151be63a1711..3a58e7ec321f649a6cae4cc0969807c2c74c6529 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -49,7 +49,7 @@ std::vector> ArrayFloatNear(const std::vector& values, return matchers; } -int SingleOpModel::AddTensor(TensorData t) { +int SingleOpModel::AddTensor(TensorData t, std::initializer_list data) { int id = tensors_.size(); // This is slightly different depending on whether we are adding a @@ -78,8 +78,23 @@ int SingleOpModel::AddTensor(TensorData t) { builder_.CreateVector({t.zero_point})); } - tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}), - t.type, /*buffer=*/0, + int buffer_id = 0; + if (data.size()) { + // Initialize buffers list with empty buffer to allow for non-const tensors. + if (buffers_.empty()) { + buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({}))); + } + + // Add data as a Buffer to buffers list. + buffer_id = buffers_.size(); + auto data_buffer = + builder_.CreateVector(reinterpret_cast(data.begin()), + sizeof(int) * data.size()); + buffers_.push_back(CreateBuffer(builder_, data_buffer)); + } + + tensors_.push_back(CreateTensor(builder_, builder_.CreateVector(t.shape), + t.type, /*buffer=*/buffer_id, /*name=*/0, q_params)); tensor_data_[id] = t; @@ -88,7 +103,15 @@ int SingleOpModel::AddTensor(TensorData t) { } int SingleOpModel::AddInput(const TensorData& t) { - int id = AddTensor(t); + int id = AddTensor(t, {}); + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddConstInput(TensorType type, + std::initializer_list data, + std::initializer_list shape) { + int id = AddTensor(TensorData{type, shape}, data); inputs_.push_back(id); return id; } @@ -100,7 +123,7 @@ int SingleOpModel::AddNullInput() { } int SingleOpModel::AddOutput(const TensorData& t) { - int id = AddTensor(t); + int id = AddTensor(t, {}); outputs_.push_back(id); return id; } @@ -142,8 +165,7 @@ void SingleOpModel::BuildInterpreter( subgraphs.push_back(subgraph); auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs); - std::vector> buffers_vec; - auto buffers = builder_.CreateVector(buffers_vec); + auto buffers = builder_.CreateVector(buffers_); auto description = builder_.CreateString("programmatic model"); builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, subgraphs_flatbuffer, description, buffers)); diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index 531c1366a87e20e140e779b767e29b1fd1111f97..cc445299ff9f0b75610c7ff38f28facbbbe5587d 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ #include @@ -98,6 +98,10 @@ class SingleOpModel { int AddInput(TensorType type) { return AddInput(TensorData{type}); } int AddInput(const TensorData& t); + // Add a Tensor containing const data and return the tensor id. + int AddConstInput(TensorType type, std::initializer_list data, + std::initializer_list shape); + // Add a null input tensor (optional input) and return kOptionalTensor. int AddNullInput(); @@ -181,7 +185,7 @@ class SingleOpModel { std::unique_ptr interpreter_; private: - int AddTensor(TensorData t); + int AddTensor(TensorData t, std::initializer_list data); std::map tensor_data_; std::vector inputs_; @@ -189,6 +193,7 @@ class SingleOpModel { std::vector> tensors_; std::vector> opcodes_; std::vector> operators_; + std::vector> buffers_; std::map> custom_registrations_; }; @@ -197,4 +202,4 @@ template <> std::vector SingleOpModel::ExtractVector(int index); } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cdb58714edb5fee771fc45f3c53a570f8fb28d1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -0,0 +1,527 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace unidirectional_sequence_lstm { + +// Input Tensors of size {max_time, n_batch, n_input} +constexpr int kInputTensor = 0; + +// Input weight tensors of size: {n_cell, n_input} +constexpr int kInputToInputWeightsTensor = 1; // Optional +constexpr int kInputToForgetWeightsTensor = 2; +constexpr int kInputToCellWeightsTensor = 3; +constexpr int kInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kRecurrentToForgetWeightsTensor = 6; +constexpr int kRecurrentToCellWeightsTensor = 7; +constexpr int kRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kCellToInputWeightsTensor = 9; // Optional +constexpr int kCellToForgetWeightsTensor = 10; // Optional +constexpr int kCellToOutputWeightsTensor = 11; // Optional + +// Gates bias tensors of size {n_cell} +constexpr int kInputGateBiasTensor = 12; // Optional +constexpr int kForgetGateBiasTensor = 13; +constexpr int kCellGateBiasTensor = 14; +constexpr int kOutputGateBiasTensor = 15; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kProjectionWeightsTensor = 16; // Optional +// Projection bias tensor of size {n_output} +constexpr int kProjectionBiasTensor = 17; // Optional + +// Output tensors. +constexpr int kScratchBufferTensor = 0; +constexpr int kOutputStateTensor = 1; +constexpr int kCellStateTensor = 2; +constexpr int kOutputTensor = 3; + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell) { + auto* params = reinterpret_cast(node->builtin_data); + + // Making sure clipping parameters have valid values. + // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, params->cell_clip >= 0); + TF_LITE_ENSURE(context, params->proj_clip >= 0); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + if (input_to_input_weights) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + if (recurrent_to_input_weights) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). + const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + if (cell_to_input_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + if (cell_to_forget_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + if (cell_to_output_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + } + + // Making sure the peephole weights are there all or none. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + } + + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); + + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + if (projection_bias) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + // TODO(ghodrat): make sure this is correct. + const bool projecton_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + + return kTfLiteOk; +} + +// Resize the output, state and scratch tensors based on the sizes of the input +// tensors. Also check that the size of the input tensors match each other. +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + + // Inferring batch size, number of outputs and sequence length and + // number of cells from the input tensors. + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, input->dims->size > 1); + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + const int n_cell = input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], + n_cell); + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + + // Get the pointer to output, state and scratch buffer tensors. + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + // TODO(ghodrat): Modify this as soon as we have a finalized method for + // scratch buffers. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + + // Resize the output and output_state tensors. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); + output_size->data[0] = max_time; + output_size->data[1] = n_batch; + output_size->data[2] = n_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + + TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); + output_state_size->data[0] = n_batch; + output_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, output_state, output_state_size)); + + // Resize the scratch buffer tensor. + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + if (use_cifg) { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 3; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } else { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Input, Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 4; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } + return kTfLiteOk; +} + +// The LSTM Op engine. +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + for (int t = 0; t < max_time; t++) { + const float* input_ptr_time = input->data.f + t * n_batch * n_input; + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, + n_batch, output_gate_scratch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights->data.f, n_cell, n_input, input_ptr_time, + n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights->data.f, n_cell, n_input, input_ptr_time, + n_batch, forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights->data.f, n_cell, n_input, input_ptr_time, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights->data.f, n_cell, n_input, input_ptr_time, + n_batch, output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, input_gate_scratch, + /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, forget_gate_scratch, + /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, output_gate_scratch, + /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, + cell_state->data.f, n_batch * n_cell, + cell_state->data.f); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, + cell_state->data.f); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, + cell_state->data.f); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, + params->cell_clip, cell_state->data.f); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, + output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights != nullptr); + const bool use_projection_bias = (projection_bias != nullptr); + float* output_ptr_time = output->data.f + t * n_batch * n_output; + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, + n_batch, output_ptr_time); + } else { + tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights->data.f, n_output, n_cell, output_gate_scratch, + n_batch, output_ptr_time, /*result_stride=*/1); + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_time, n_batch * n_output, + params->proj_clip, output_ptr_time); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_time); + } + tensor_utils::CopyVector(output_ptr_time, n_batch * n_output, + output_state->data.f); + } + return kTfLiteOk; +} + +} // namespace unidirectional_sequence_lstm + +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + unidirectional_sequence_lstm::Prepare, + unidirectional_sequence_lstm::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..93b635ae576e99854796d9fa997e5bf355b20534 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -0,0 +1,1089 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Sequential LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class UnidirectionalLSTMOpModel : public SingleOpModel { + public: + UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + int sequence_length, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, + float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output), + sequence_length_(sequence_length) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + int sequence_length() { return sequence_length_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; + int sequence_length_; +}; + +TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + UnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/false, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + // Input should have n_input * sequence_length many values. + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, + -0.15358765, -0.03716109, 0.12507336, + 0.41193449, -0.20860538, -0.15053082, + 0.09120187, 0.24278517, -0.12222792}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output; + float* golden_end = + golden_start + lstm.num_outputs() * lstm.sequence_length(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + UnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, + /*use_peephole=*/true, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, + -0.05163646, -0.42312205, -0.01218222, + 0.24201041, -0.08124574, -0.358325, + -0.04621704, 0.21641694, -0.06471302}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output; + float* golden_end = + golden_start + lstm.num_outputs() * lstm.sequence_length(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + const int sequence_length = 4; + + UnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/true, /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias( + {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias( + {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights( + {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights( + {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights( + {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + for (int i = 0; i < lstm.sequence_length(); i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end); + } + + lstm.Invoke(); + + std::vector expected; + for (int i = 0; i < lstm.sequence_length(); i++) { + float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); + float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + } + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h index b11d86c375ca6bd8693f2271df63ecb3c87657de..5cd6c208500f3ea84ab8146f7f136e8b7851ff03 100644 --- a/tensorflow/contrib/lite/memory_planner.h +++ b/tensorflow/contrib/lite/memory_planner.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ +#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ #include "tensorflow/contrib/lite/context.h" @@ -42,4 +42,4 @@ class MemoryPlanner { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 4b0c853f77c102efa7574ff97c254d92504730a3..303a10af03e582d5e4e641c15072e1c9d594e1f4 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -463,6 +463,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { TfLiteLSTMParams* params = MallocPOD(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { @@ -475,35 +476,9 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_RESIZE_BILINEAR: { - auto* params = MallocPOD(); - if (auto* schema_params = - op->builtin_options_as_ResizeBilinearOptions()) { - params->new_height = schema_params->new_height(); - params->new_width = schema_params->new_width(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_PAD: { - auto* params = MallocPOD(); - if (auto* schema_params = op->builtin_options_as_PadOptions()) { - auto* before_padding = schema_params->before_padding(); - FlatBufferIntVectorToArray(sizeof(params->before_padding), - before_padding, params->before_padding, - error_reporter); - - auto* after_padding = schema_params->after_padding(); - FlatBufferIntVectorToArray(sizeof(params->after_padding), after_padding, - params->after_padding, error_reporter); - - if (before_padding->Length() != after_padding->Length()) { - error_reporter->Report( - "Before padding and after padding arrays need to contain the " - "same number of dimensions.\n"); - } - params->num_dimensions = after_padding->Length(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_RESHAPE: { @@ -617,6 +592,18 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_STRIDED_SLICE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { + params->begin_mask = schema_params->begin_mask(); + params->end_mask = schema_params->end_mask(); + params->ellipsis_mask = schema_params->ellipsis_mask(); + params->new_axis_mask = schema_params->new_axis_mask(); + params->shrink_axis_mask = schema_params->shrink_axis_mask(); + } + builtin_data = reinterpret_cast(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index e0c96f7f0480cd3146f95a22957477809cf0096d..a467df5bb4eee3f6ce814512cb8b74bf09a6a4e7 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -31,8 +31,8 @@ limitations under the License. // OpResolver must be defined to provide your kernel implementations to the // interpreter. This is environment specific and may consist of just the builtin // ops, or some custom operators you defined to extend tflite. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_MODEL_H_ #include #include "tensorflow/contrib/lite/error_reporter.h" @@ -173,4 +173,4 @@ class InterpreterBuilder { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h index d17323a3f9a0ea80ad5e215b0a4700e625d0c590..90260c8d620b0e756f72089d3f4d8d9f92d44fbe 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.h +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#define TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ #include #include @@ -77,4 +77,4 @@ struct SmartReplyConfig { } // namespace custom } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc index 97d3c650e21c3cb4bef1db09df93f4bf24f38ba5..e6c8d966f1aff5a867f9469f8fcdec526df84763 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -22,8 +22,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" -#include "tensorflow/contrib/lite/models/test_utils.h" +//#include "tensorflow/contrib/lite/models/test_utils.h" #include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/core/platform/test.h" namespace tflite { namespace custom { @@ -33,6 +34,11 @@ namespace { const char kModelName[] = "smartreply_ondevice_model.bin"; const char kSamples[] = "smartreply_samples.tsv"; +string TestDataPath() { + return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", + "contrib/lite/models/testdata/")); +} + MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { bool has_expected_response = false; for (const auto &item : *arg) { diff --git a/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc deleted file mode 100644 index bf95b313f31c2f76046727353a9a7b0658dbf067..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for speech ASR AM model using TFLite Ops. - -#include - -#include -#include - -#include "base/logging.h" -#include "file/base/path.h" -#include "testing/base/public/googletest.h" -#include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/models/test_utils.h" - -namespace tflite { -namespace models { - -constexpr int kModelInputTensor = 0; -constexpr int kLstmLayer1OutputStateTensor = 19; -constexpr int kLstmLayer1CellStateTensor = 20; -constexpr int kLstmLayer2OutputStateTensor = 40; -constexpr int kLstmLayer2CellStateTensor = 41; -constexpr int kLstmLayer3OutputStateTensor = 61; -constexpr int kLstmLayer3CellStateTensor = 62; -constexpr int kLstmLayer4OutputStateTensor = 82; -constexpr int kLstmLayer4CellStateTensor = 83; -constexpr int kLstmLayer5OutputStateTensor = 103; -constexpr int kLstmLayer5CellStateTensor = 104; -constexpr int kModelOutputTensor = 109; - -TEST(SpeechAsrAm, RandomIOTest) { - // Read the model. - string tflite_file_path = - file::JoinPath(TestDataPath(), "speech_asr_am_model.tflite"); - auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); - CHECK(model) << "Failed to mmap model " << tflite_file_path; - - // Initialize the interpreter. - ops::builtin::BuiltinOpResolver builtins; - std::unique_ptr interpreter; - InterpreterBuilder(*model, builtins)(&interpreter); - CHECK(interpreter != nullptr); - interpreter->AllocateTensors(); - - // Load the input frames. - Frames input_frames; - const string input_file_path = - file::JoinPath(TestDataPath(), "speech_asr_am_model_in.csv"); - ReadFrames(input_file_path, &input_frames); - - // Load the golden output results. - Frames output_frames; - const string output_file_path = - file::JoinPath(TestDataPath(), "speech_asr_am_model_out.csv"); - ReadFrames(output_file_path, &output_frames); - - const int speech_batch_size = - interpreter->tensor(kModelInputTensor)->dims->data[0]; - const int speech_input_size = - interpreter->tensor(kModelInputTensor)->dims->data[1]; - const int speech_output_size = - interpreter->tensor(kModelOutputTensor)->dims->data[1]; - - float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; - float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; - - // Clear the LSTM state for layers. - memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer4CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer5CellStateTensor)->bytes); - - - for (int i = 0; i < input_frames.size(); i++) { - // Feed the input to model. - int frame_ptr = 0; - for (int k = 0; k < speech_input_size * speech_batch_size; k++) { - input_ptr[k] = input_frames[i][frame_ptr++]; - } - // Run the model. - interpreter->Invoke(); - // Validate the output. - for (int k = 0; k < speech_output_size; k++) { - ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4); - } - } -} - -} // namespace models -} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc deleted file mode 100644 index 53f2b66da492f8fe56fa9e234f0951cf61c35037..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for speech ASR LM model using TFLite Ops. - -#include - -#include -#include - -#include "base/logging.h" -#include "file/base/path.h" -#include "testing/base/public/googletest.h" -#include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/models/test_utils.h" - -namespace tflite { -namespace models { - -constexpr int kModelInput1Tensor = 0; -constexpr int kModelInput2Tensor = 66; -constexpr int kLstmLayer1OutputStateTensor = 21; -constexpr int kLstmLayer1CellStateTensor = 22; -constexpr int kLstmLayer2OutputStateTensor = 42; -constexpr int kLstmLayer2CellStateTensor = 43; -constexpr int kLstmLayer3OutputStateTensor = 63; -constexpr int kLstmLayer3CellStateTensor = 64; -constexpr int kModelOutputTensor = 75; - -static void ClearLstmStates(Interpreter* interpreter) { - memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); -} - -TEST(SpeechAsrLm, EndToEndTest) { - // Read the model. - string tflite_file_path = - file::JoinPath(TestDataPath(), "speech_asr_lm_model.tflite"); - auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); - CHECK(model) << "Failed to mmap model " << tflite_file_path; - - // Initialize the interpreter. - ops::builtin::BuiltinOpResolver builtins; - std::unique_ptr interpreter; - InterpreterBuilder(*model, builtins)(&interpreter); - CHECK(interpreter != nullptr); - interpreter->AllocateTensors(); - - // Load the input frames. - Frames input_frames; - const string input_file_path = - file::JoinPath(TestDataPath(), "speech_asr_lm_model_in.csv"); - ReadFrames(input_file_path, &input_frames); - - // Load the golden output results. - Frames output_frames; - const string output_file_path = - file::JoinPath(TestDataPath(), "speech_asr_lm_model_out.csv"); - ReadFrames(output_file_path, &output_frames); - - CHECK_EQ(interpreter->tensor(kModelInput1Tensor)->dims->size, 1); - const int input1_size = - interpreter->tensor(kModelInput1Tensor)->dims->data[0]; - CHECK_EQ(input1_size, 1); - CHECK_EQ(interpreter->tensor(kModelInput2Tensor)->dims->size, 1); - const int output_size = - interpreter->tensor(kModelOutputTensor)->dims->data[0]; - CHECK_EQ(output_size, 1); - - int* input_lookup_ptr = interpreter->tensor(kModelInput1Tensor)->data.i32; - int* output_lookup_ptr = interpreter->tensor(kModelInput2Tensor)->data.i32; - float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; - - - for (int i = 0; i < input_frames.size(); i++) { - float output_score = 0.0f; - // Reset LSTM states for each sequence. - ClearLstmStates(interpreter.get()); - // For subsequent inputs feed them sequentially, one-by-one. - for (int k = 1; k < input_frames[i].size(); k++) { - // Feed the inputs to model. - input_lookup_ptr[0] = static_cast(input_frames[i][k - 1]); - output_lookup_ptr[0] = static_cast(input_frames[i][k]); - // Run the model. - interpreter->Invoke(); - // Sum up the outputs. - output_score += output_ptr[0]; - } - // Validate the output. - ASSERT_NEAR(output_score, output_frames[i][0], 1.4e-5); - } -} - -} // namespace models -} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc b/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc deleted file mode 100644 index f7e136113aa056fdc87378f8c902f53c811cd39c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for speech EndPointer model using TFLite Ops. - -#include - -#include -#include - -#include "base/logging.h" -#include "testing/base/public/googletest.h" -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/models/test_utils.h" - -namespace tflite { -namespace models { - -constexpr int kModelInputTensor = 0; -constexpr int kLstmLayer1OutputStateTensor = 28; -constexpr int kLstmLayer1CellStateTensor = 29; -constexpr int kLstmLayer2OutputStateTensor = 49; -constexpr int kLstmLayer2CellStateTensor = 50; -constexpr int kModelOutputTensor = 58; - -TEST(SpeechEndpointer, EndpointerTest) { - // Read the model. - string tflite_file_path = - StrCat(TestDataPath(), "/", "speech_endpointer_model.tflite"); - auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); - CHECK(model) << "Failed to read model from file " << tflite_file_path; - - // Initialize the interpreter. - ops::builtin::BuiltinOpResolver builtins; - std::unique_ptr interpreter; - InterpreterBuilder(*model, builtins)(&interpreter); - CHECK(interpreter != nullptr); - interpreter->AllocateTensors(); - - // Load the input frames. - Frames input_frames; - const string input_file_path = - StrCat(TestDataPath(), "/", "speech_endpointer_model_in.csv"); - ReadFrames(input_file_path, &input_frames); - - // Load the golden output results. - Frames output_frames; - const string output_file_path = - StrCat(TestDataPath(), "/", "speech_endpointer_model_out.csv"); - ReadFrames(output_file_path, &output_frames); - - const int speech_batch_size = - interpreter->tensor(kModelInputTensor)->dims->data[0]; - const int speech_input_size = - interpreter->tensor(kModelInputTensor)->dims->data[1]; - const int speech_output_size = - interpreter->tensor(kModelOutputTensor)->dims->data[1]; - - float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; - float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; - - // Clear the LSTM state for layers. - memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); - - for (int i = 0; i < input_frames.size(); i++) { - // Feed the input to model. - int frame_ptr = 0; - for (int k = 0; k < speech_input_size * speech_batch_size; k++) { - input_ptr[k] = input_frames[i][frame_ptr++]; - } - // Run the model. - interpreter->Invoke(); - // Validate the output. - for (int k = 0; k < speech_output_size; k++) { - ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); - } - } -} - -} // namespace models -} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc deleted file mode 100644 index f69cae8d2cb08678f9eec8c9b9d653cfce55bd2e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for speech Hotword model using TFLite Ops. - -#include - -#include -#include - -#include "base/logging.h" -#include "testing/base/public/googletest.h" -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/models/test_utils.h" - -namespace tflite { -namespace models { - -void RunTest(int model_input_tensor, int svdf_layer_state_tensor, - int model_output_tensor, const string& model_name, - const string& golden_in_name, const string& golden_out_name) { - // Read the model. - string tflite_file_path = StrCat(TestDataPath(), "/", model_name); - auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); - CHECK(model) << "Failed to read model from file " << tflite_file_path; - - // Initialize the interpreter. - ops::builtin::BuiltinOpResolver builtins; - std::unique_ptr interpreter; - InterpreterBuilder(*model, builtins)(&interpreter); - CHECK(interpreter != nullptr); - interpreter->AllocateTensors(); - - // Reset the SVDF layer state. - memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0, - interpreter->tensor(svdf_layer_state_tensor)->bytes); - - // Load the input frames. - Frames input_frames; - const string input_file_path = StrCat(TestDataPath(), "/", golden_in_name); - ReadFrames(input_file_path, &input_frames); - - // Load the golden output results. - Frames output_frames; - const string output_file_path = StrCat(TestDataPath(), "/", golden_out_name); - ReadFrames(output_file_path, &output_frames); - - const int speech_batch_size = - interpreter->tensor(model_input_tensor)->dims->data[0]; - const int speech_input_size = - interpreter->tensor(model_input_tensor)->dims->data[1]; - const int speech_output_size = - interpreter->tensor(model_output_tensor)->dims->data[1]; - const int input_sequence_size = - input_frames[0].size() / (speech_input_size * speech_batch_size); - float* input_ptr = interpreter->tensor(model_input_tensor)->data.f; - float* output_ptr = interpreter->tensor(model_output_tensor)->data.f; - - // The first layer (SVDF) input size is 40 (speech_input_size). Each speech - // input frames for this model is 1600 floats, which can be fed to input in a - // sequence of size 40 (input_sequence_size). - for (int i = 0; i < TestInputSize(input_frames); i++) { - int frame_ptr = 0; - for (int s = 0; s < input_sequence_size; s++) { - for (int k = 0; k < speech_input_size * speech_batch_size; k++) { - input_ptr[k] = input_frames[i][frame_ptr++]; - } - interpreter->Invoke(); - } - // After the whole frame (1280 floats) is fed, we can check the output frame - // matches with the golden output frame. - for (int k = 0; k < speech_output_size; k++) { - ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); - } - } -} - -TEST(SpeechHotword, OkGoogleTestRank1) { - constexpr int kModelInputTensor = 0; - constexpr int kSvdfLayerStateTensor = 4; - constexpr int kModelOutputTensor = 18; - - RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, - "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", - "speech_hotword_model_out_rank1.csv"); -} - -TEST(SpeechHotword, OkGoogleTestRank2) { - constexpr int kModelInputTensor = 17; - constexpr int kSvdfLayerStateTensor = 1; - constexpr int kModelOutputTensor = 18; - RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, - "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", - "speech_hotword_model_out_rank2.csv"); -} - -} // namespace models -} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc deleted file mode 100644 index e208fac8dfcb1b84e9884d303ac9b8a67d4fa47f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for speech SpeakerId model using TFLite Ops. - -#include - -#include -#include - -#include "base/logging.h" -#include "testing/base/public/googletest.h" -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/models/test_utils.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" - -void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); - -namespace tflite { -namespace models { - -constexpr int kModelInputTensor = 0; -constexpr int kLstmLayer1OutputStateTensor = 19; -constexpr int kLstmLayer1CellStateTensor = 20; -constexpr int kLstmLayer2OutputStateTensor = 40; -constexpr int kLstmLayer2CellStateTensor = 41; -constexpr int kLstmLayer3OutputStateTensor = 61; -constexpr int kLstmLayer3CellStateTensor = 62; -constexpr int kModelOutputTensor = 66; - -void SpeakerIdTest(bool useNNAPI) { - // Read the model. - string tflite_file_path = - StrCat(TestDataPath(), "/", "speech_speakerid_model.tflite"); - auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); - CHECK(model) << "Failed to read model from file " << tflite_file_path; - - // Initialize the interpreter. - ::tflite::MutableOpResolver resolver; - RegisterSelectedOps(&resolver); - std::unique_ptr interpreter; - InterpreterBuilder(*model, resolver)(&interpreter); - CHECK(interpreter != nullptr); - - interpreter->UseNNAPI(useNNAPI); - - interpreter->AllocateTensors(); - - // Load the input frames. - Frames input_frames; - const string input_file_path = - StrCat(TestDataPath(), "/", "speech_speakerid_model_in.csv"); - ReadFrames(input_file_path, &input_frames); - - // Load the golden output results. - Frames output_frames; - const string output_file_path = - StrCat(TestDataPath(), "/", "speech_speakerid_model_out.csv"); - ReadFrames(output_file_path, &output_frames); - - const int speech_batch_size = - interpreter->tensor(kModelInputTensor)->dims->data[0]; - const int speech_input_size = - interpreter->tensor(kModelInputTensor)->dims->data[1]; - const int speech_output_size = - interpreter->tensor(kModelOutputTensor)->dims->data[1]; - - float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; - float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; - - // Clear the LSTM state for layers. - memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); - for (int i = 0; i < input_frames.size(); i++) { - // Feed the input to model. - int frame_ptr = 0; - for (int k = 0; k < speech_input_size * speech_batch_size; k++) { - input_ptr[k] = input_frames[i][frame_ptr++]; - } - // Run the model. - interpreter->Invoke(); - // Validate the output. - for (int k = 0; k < speech_output_size; k++) { - ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); - } - } -} - -TEST(SpeechSpeakerId, OkGoogleTest) { SpeakerIdTest(false); } - -TEST(SpeechSpeakerId, OkGoogleTestUsingNNAPI) { SpeakerIdTest(true); } - -} // namespace models -} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..daa8c3100b64e9290256aa14a6ab641f19174a0a --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_test.cc @@ -0,0 +1,189 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech models (Hotword, SpeakerId) using TFLite Ops. + +#include +#include + +#include + +#include "testing/base/public/googletest.h" +#include +#include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +namespace tflite { +namespace { + +const char kDataPath[] = "third_party/tensorflow/contrib/lite/models/testdata/"; + +bool Init(const string& in_file_name, testing::TfLiteDriver* driver, + std::ifstream* in_file) { + driver->SetModelBaseDir(kDataPath); + in_file->open(string(kDataPath) + in_file_name, std::ifstream::in); + return in_file->is_open(); +} + +// Converts a set of test files provided by the speech team into a single +// test_spec. Input CSV files are supposed to contain a number of sequences per +// line. Each sequence maps to a single invocation of the interpreter and the +// output tensor after all sequences have run is compared to the corresponding +// line in the output CSV file. +bool ConvertCsvData(const string& model_name, const string& in_name, + const string& out_name, const string& input_tensor, + const string& output_tensor, + const string& persistent_tensors, int sequence_size, + std::ostream* out) { + auto data_path = [](const string& s) { return string(kDataPath) + s; }; + + *out << "load_model: \"" << data_path(model_name) << "\"" << std::endl; + + *out << "init_state: \"" << persistent_tensors << "\"" << std::endl; + + string in_file_name = data_path(in_name); + std::ifstream in_file(in_file_name); + if (!in_file.is_open()) { + std::cerr << "Failed to open " << in_file_name << std::endl; + return false; + } + string out_file_name = data_path(out_name); + std::ifstream out_file(out_file_name); + if (!out_file.is_open()) { + std::cerr << "Failed to open " << out_file_name << std::endl; + return false; + } + + int invocation_count = 0; + string in_values; + while (std::getline(in_file, in_values, '\n')) { + std::vector input = testing::Split(in_values, ","); + int num_sequences = input.size() / sequence_size; + + for (int j = 0; j < num_sequences; ++j) { + *out << "invoke {" << std::endl; + *out << " id: " << invocation_count << std::endl; + *out << " input: \""; + for (int k = 0; k < sequence_size; ++k) { + *out << input[k + j * sequence_size] << ","; + } + *out << "\"" << std::endl; + + if (j == num_sequences - 1) { + string out_values; + if (!std::getline(out_file, out_values, '\n')) { + std::cerr << "Not enough lines in " << out_file_name << std::endl; + return false; + } + *out << " output: \"" << out_values << "\"" << std::endl; + } + + *out << "}" << std::endl; + ++invocation_count; + } + } + return true; +} + +TEST(SpeechTest, HotwordOkGoogleRank1Test) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank1.csv", /*input_tensor=*/"0", + /*output_tensor=*/"18", /*persistent_tensors=*/"4", + /*sequence_size=*/40, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, HotwordOkGoogleRank2Test) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank2.csv", /*input_tensor=*/"17", + /*output_tensor=*/"18", /*persistent_tensors=*/"1", + /*sequence_size=*/40, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, SpeakerIdOkGoogleTest) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv", + "speech_speakerid_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"66", + /*persistent_tensors=*/"19,20,40,41,61,62", + /*sequence_size=*/80, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, AsrAmTest) { + std::stringstream os; + ASSERT_TRUE( + ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv", + "speech_asr_am_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"109", + /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104", + /*sequence_size=*/320, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +// The original version of speech_asr_lm_model_test.cc ran a few sequences +// through the interpreter and stored the sum of all the output, which was them +// compared for correctness. In this test we are comparing all the intermediate +// results. +TEST(SpeechTest, AsrLmTest) { + std::ifstream in_file; + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file)); + ASSERT_TRUE(testing::ParseAndRunTests(&in_file, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, EndpointerTest) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv", + "speech_endpointer_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"58", + /*persistent_tensors=*/"28,29,49,50", + /*sequence_size=*/320, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +TEST(SpeechTest, TtsTest) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite", + "speech_tts_model_in.csv", + "speech_tts_model_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"74", + /*persistent_tensors=*/"25,26,46,47,67,68,73", + /*sequence_size=*/334, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + << test_driver.GetErrorMessage(); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc deleted file mode 100644 index 88291776892f3186ca5bfc726e814f8d23d73b11..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/models/speech_tts_model_test.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for speech TTS model using TFLite Ops. - -#include - -#include -#include - -#include "base/logging.h" -#include "testing/base/public/googletest.h" -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/models/test_utils.h" - -namespace tflite { -namespace models { - -constexpr int kModelInputTensor = 0; -constexpr int kLstmLayer1OutputStateTensor = 25; -constexpr int kLstmLayer1CellStateTensor = 26; -constexpr int kLstmLayer2OutputStateTensor = 46; -constexpr int kLstmLayer2CellStateTensor = 47; -constexpr int kLstmLayer3OutputStateTensor = 67; -constexpr int kLstmLayer3CellStateTensor = 68; -constexpr int kRnnLayerHiddenStateTensor = 73; -constexpr int kModelOutputTensor = 74; - -TEST(SpeechTTS, RandomIOTest) { - // Read the model. - string tflite_file_path = - StrCat(TestDataPath(), "/", "speech_tts_model.tflite"); - auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); - CHECK(model) << "Failed to mmap model " << tflite_file_path; - - // Initialize the interpreter. - ops::builtin::BuiltinOpResolver builtins; - std::unique_ptr interpreter; - InterpreterBuilder(*model, builtins)(&interpreter); - CHECK(interpreter != nullptr); - interpreter->AllocateTensors(); - - // Load the input frames. - Frames input_frames; - const string input_file_path = - StrCat(TestDataPath(), "/", "speech_tts_model_in.csv"); - ReadFrames(input_file_path, &input_frames); - - // Load the golden output results. - Frames output_frames; - const string output_file_path = - StrCat(TestDataPath(), "/", "speech_tts_model_out.csv"); - ReadFrames(output_file_path, &output_frames); - - const int speech_batch_size = - interpreter->tensor(kModelInputTensor)->dims->data[0]; - const int speech_input_size = - interpreter->tensor(kModelInputTensor)->dims->data[1]; - const int speech_output_size = - interpreter->tensor(kModelOutputTensor)->dims->data[1]; - - float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; - float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; - - // Clear the LSTM state for layers. - memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); - - memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); - memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, - interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); - - memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0, - interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes); - - for (int i = 0; i < input_frames.size(); i++) { - // Feed the input to model. - int frame_ptr = 0; - for (int k = 0; k < speech_input_size * speech_batch_size; k++) { - input_ptr[k] = input_frames[i][frame_ptr++]; - } - // Run the model. - interpreter->Invoke(); - // Validate the output. - for (int k = 0; k < speech_output_size; k++) { - ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); - } - } -} - -} // namespace models -} // namespace tflite diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h deleted file mode 100644 index 1e14c26a3544ed44f9395ff3b59a70551a1a6394..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/models/test_utils.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ - -#include -#include - -#include -#include -#include -#include - -namespace tflite { -namespace models { -using Frames = std::vector>; -} // namespace models -} // namespace tflite - -#ifndef __ANDROID__ -#include "absl/strings/str_cat.h" -#include "tensorflow/core/platform/test.h" - -inline string TestDataPath() { - return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", - "contrib/lite/models/testdata/")); -} -inline int TestInputSize(const tflite::models::Frames& input_frames) { - return input_frames.size(); -} -#else -inline string TestDataPath() { - return string("third_party/tensorflow/contrib/lite/models/testdata/"); -} - -inline int TestInputSize(const tflite::models::Frames& input_frames) { - // Android TAP is very slow, we only test the first 20 frames. - return 20; -} -#endif - -namespace tflite { -namespace models { - -// Read float data from a comma-separated file: -// Each line will be read into a float vector. -// The return result will be a vector of float vectors. -void ReadFrames(const string& csv_file_path, Frames* frames) { - std::ifstream csv_file(csv_file_path); - string line; - while (std::getline(csv_file, line, '\n')) { - std::vector fields; - // Used by strtok_r internaly for successive calls on the same string. - char* save_ptr = nullptr; - - // Tokenize the line. - char* next_token = - strtok_r(const_cast(line.c_str()), ",", &save_ptr); - while (next_token != nullptr) { - float f = strtod(next_token, nullptr); - fields.push_back(f); - next_token = strtok_r(nullptr, ",", &save_ptr); - } - frames->push_back(fields); - } - csv_file.close(); -} - -} // namespace models -} // namespace tflite - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md index 667a58838329145a9500576749c5aa497641d61c..1c47e00aae2a0e76ba04004a2fc3cc02ec4536f7 100644 --- a/tensorflow/contrib/lite/models/testdata/g3doc/README.md +++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md @@ -53,7 +53,7 @@ with the corresponding parameters as shown in the figure. ### Automatic Speech Recognizer (ASR) Acoustic Model (AM) The acoustic model for automatic speech recognition is the neural network model -for matching phonemes to the input autio features. It generates posterior +for matching phonemes to the input audio features. It generates posterior probabilities of phonemes from speech frontend features (log-mel filterbanks). It has an input size of 320 (float), an output size of 42 (float), five LSTM layers and one fully connected layers with a Softmax activation function, with @@ -68,7 +68,7 @@ for predicting the probability of a word given previous words in a sentence. It generates posterior probabilities of the next word based from a sequence of words. The words are encoded as indices in a fixed size dictionary. The model has two inputs both of size one (integer): the current word index and -next word index, an output size of one (float): the log probability. It consits +next word index, an output size of one (float): the log probability. It consists of three embedding layer, three LSTM layers, followed by a multiplication, a fully connected layers and an addition. The corresponding parameters as shown in the figure. diff --git a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec new file mode 100644 index 0000000000000000000000000000000000000000..5812de4b30382f6b031c907bf8bd12a34ac9e0b3 --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec @@ -0,0 +1,202 @@ +load_model: "speech_asr_lm_model.tflite" +init_state: "21,22,42,43,63,64" +invoke { + id: 3 + input: "63982" + input: "8409" + output: "-2.75389" +} +invoke { + id: 4 + input: "8409" + input: "1488" + output: "0.601841" +} +invoke { + id: 5 + input: "1488" + input: "63981" + output: "-0.314846" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 6 + input: "63982" + input: "8409" + output: "-2.75389" +} +invoke { + id: 7 + input: "8409" + input: "3082" + output: "-3.63721" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 8 + input: "63982" + input: "8409" + output: "-2.75389" +} +invoke { + id: 9 + input: "8409" + input: "18965" + output: "-6.93985" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 13 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 14 + input: "12516" + input: "914" + output: "-0.407277" +} +invoke { + id: 15 + input: "914" + input: "63981" + output: "-3.82091" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 19 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 20 + input: "12516" + input: "914" + output: "-0.407277" +} +invoke { + id: 21 + input: "914" + input: "48619" + output: "-4.02131" +} +invoke { + id: 22 + input: "48619" + input: "63981" + output: "-0.677399" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 26 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 27 + input: "12516" + input: "914" + output: "-0.407277" +} +invoke { + id: 28 + input: "914" + input: "4700" + output: "-4.056" +} +invoke { + id: 29 + input: "4700" + input: "63981" + output: "0.415889" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 30 + input: "63982" + input: "12516" + output: "-6.20867" +} +invoke { + id: 31 + input: "12516" + input: "914" + output: "-0.407277" +invoke { + id: 32 + input: "914" + input: "51923" + output: "-14.1147" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 34 + input: "63982" + input: "5520" + output: "-4.56971" +} +invoke { + id: 35 + input: "5520" + input: "16318" + output: "-1.54815" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 36 + input: "63982" + input: "5520" + output: "-4.56971" +} +invoke { + id: 37 + input: "5520" + input: "28303" + output: "-14.0947" +} +init_state: "21,22,42,43,63,64" +invoke { + id: 38 + input: "63982" + input: "12451" + output: "-6.24243" +} +invoke { + id: 39 + input: "12451" + input: "752" + output: "0.0700736" +} +invoke { + id: 40 + input: "752" + input: "11" + output: "-1.72744" +} +invoke { + id: 41 + input: "11" + input: "19454" + output: "-3.19211" +} +invoke { + id: 42 + input: "19454" + input: "16989" + output: "-4.01684" +} +invoke { + id: 43 + input: "16989" + input: "40168" + output: "-8.91317" +} +invoke { + id: 44 + input: "40168" + input: "63981" + output: "-0.675377" +} diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h index 3cda4bccccd0c30bb0ccfb82e1c80f7c6a7b9d84..7019c29959fc02f4f84d1e4c8cf280751e585de0 100644 --- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -370,7 +370,7 @@ enum { * Looks up items from a given tensor. * * Each item in the output is a raw copy of the corresponding item in - * the input “values”. If the the given “lookup” indices are out of bounds, + * the input “values”. If the given “lookup” indices are out of bounds, * the op will fail and an error will be reported. * * Inputs: diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index b3602f799e7d05bcd837135ca60cb410ac1a4fe4..d5b9319407a461c571411c44ae702c137c914fa9 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -322,6 +322,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_EMBEDDING_LOOKUP: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: + case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: case tflite::BuiltinOperator_MUL: @@ -338,6 +339,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_DIV: case tflite::BuiltinOperator_SUB: case tflite::BuiltinOperator_SQUEEZE: + case tflite::BuiltinOperator_STRIDED_SLICE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index f29aa9e18e605ef0b5d246b2a672639c64391646..e98000929a1168c786f6c18f498f9d1d72311ada 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" @@ -63,4 +63,4 @@ class NNAPIDelegate { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ +#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h index 54d48760951c946d0493a86961348df25e53bd1f..1b6998cda382782b974bea3d18ffb6217e8f780c 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.h +++ b/tensorflow/contrib/lite/optional_debug_tools.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Optional debugging functionality. For small sized binaries, these are not // needed. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ +#define TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ #include "tensorflow/contrib/lite/interpreter.h" @@ -29,4 +29,4 @@ TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 4d87a5907b1335794e57689f144e03747cec9e70..3c369774beda57cca3bc1ea0ab9a9ad619841e7e 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -31,10 +31,18 @@ import tempfile from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 -from tensorflow.contrib.lite.toco.python.tensorflow_wrap_toco import TocoConvert as _toco_convert_protos from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.lazy_loader import LazyLoader + +# Lazy load since some of the performance benchmark skylark rules +# break dependencies. +_toco_python = LazyLoader( + "tensorflow_wrap_toco", globals(), + "tensorflow.contrib.lite.toco.python." + "tensorflow_wrap_toco") +del LazyLoader # Enum types from the protobuf promoted to the API FLOAT = _types_pb2.FLOAT @@ -86,7 +94,8 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): # TODO(aselle): When toco does not use fatal errors for failure, we can # switch this on. if not _toco_from_proto_bin: - return _toco_convert_protos(model_flags_str, toco_flags_str, input_data_str) + return _toco_python.TocoConvert( + model_flags_str, toco_flags_str, input_data_str) with tempfile.NamedTemporaryFile() as fp_toco, \ tempfile.NamedTemporaryFile() as fp_model, \ diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 260a87c93bf2886de5f951af9f3fd20d4c33bb83..ec202cd4073f152e1b2f4d5efd443615e901afc6 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -53,12 +53,12 @@ table Tensor { type:TensorType; // An index that refers to the buffers table at the root of the model. Or, // if there is no data buffer associated (i.e. intermediate results), then - // this is 0 (which refers to an always existant empty buffer). + // this is 0 (which refers to an always existent empty buffer). // // The data_buffer itself is an opaque container, with the assumption that the // target device is little-endian. In addition, all builtin operators assume // the memory is ordered such that if `shape` is [4, 3, 2], then index - // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. buffer:uint; name:string; // For debugging and importing back into tensorflow. quantization:QuantizationParameters; // Optional. @@ -117,6 +117,8 @@ enum BuiltinOperator : byte { SUB = 41, DIV = 42, SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, } // Options for the builtin operators. @@ -152,6 +154,7 @@ union BuiltinOptions { DivOptions, SqueezeOptions, SequenceRNNOptions, + StridedSliceOptions, } enum Padding : byte { SAME, VALID } @@ -263,8 +266,6 @@ table LSTMOptions { } table ResizeBilinearOptions { - new_height:int; - new_width:int; } // A call operation options @@ -274,8 +275,6 @@ table CallOptions { } table PadOptions { - before_padding:[int]; - after_padding:[int]; } table ReshapeOptions { @@ -339,6 +338,14 @@ table SqueezeOptions { squeeze_dims:[int]; } +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h old mode 100755 new mode 100644 index fd98be8f70ee06024142cb8c2099fc07ffebcb87..c04a73a2bf00807442967499cceaaee941e54278 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -120,6 +120,9 @@ struct MeanOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; +struct StridedSliceOptions; +struct StridedSliceOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -206,11 +209,13 @@ enum BuiltinOperator { BuiltinOperator_SUB = 41, BuiltinOperator_DIV = 42, BuiltinOperator_SQUEEZE = 43, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + BuiltinOperator_STRIDED_SLICE = 45, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_SQUEEZE + BuiltinOperator_MAX = BuiltinOperator_STRIDED_SLICE }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[41] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[43] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -252,7 +257,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[41] { BuiltinOperator_MEAN, BuiltinOperator_SUB, BuiltinOperator_DIV, - BuiltinOperator_SQUEEZE}; + BuiltinOperator_SQUEEZE, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_STRIDED_SLICE}; return values; } @@ -301,6 +308,8 @@ inline const char **EnumNamesBuiltinOperator() { "SUB", "DIV", "SQUEEZE", + "UNIDIRECTIONAL_SEQUENCE_LSTM", + "STRIDED_SLICE", nullptr}; return names; } @@ -343,11 +352,12 @@ enum BuiltinOptions { BuiltinOptions_DivOptions = 29, BuiltinOptions_SqueezeOptions = 30, BuiltinOptions_SequenceRNNOptions = 31, + BuiltinOptions_StridedSliceOptions = 32, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_SequenceRNNOptions + BuiltinOptions_MAX = BuiltinOptions_StridedSliceOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -380,7 +390,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] { BuiltinOptions_SubOptions, BuiltinOptions_DivOptions, BuiltinOptions_SqueezeOptions, - BuiltinOptions_SequenceRNNOptions}; + BuiltinOptions_SequenceRNNOptions, + BuiltinOptions_StridedSliceOptions}; return values; } @@ -417,6 +428,7 @@ inline const char **EnumNamesBuiltinOptions() { "DivOptions", "SqueezeOptions", "SequenceRNNOptions", + "StridedSliceOptions", nullptr}; return names; } @@ -590,6 +602,11 @@ struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; }; +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -947,6 +964,16 @@ struct BuiltinOptionsUnion { ? reinterpret_cast(value) : nullptr; } + StridedSliceOptionsT *AsStridedSliceOptions() { + return type == BuiltinOptions_StridedSliceOptions + ? reinterpret_cast(value) + : nullptr; + } + const StridedSliceOptionsT *AsStridedSliceOptions() const { + return type == BuiltinOptions_StridedSliceOptions + ? reinterpret_cast(value) + : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, @@ -2630,26 +2657,13 @@ flatbuffers::Offset CreateCallOptions( struct PadOptionsT : public flatbuffers::NativeTable { typedef PadOptions TableType; - std::vector before_padding; - std::vector after_padding; PadOptionsT() {} }; struct PadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef PadOptionsT NativeTableType; - enum { VT_BEFORE_PADDING = 4, VT_AFTER_PADDING = 6 }; - const flatbuffers::Vector *before_padding() const { - return GetPointer *>(VT_BEFORE_PADDING); - } - const flatbuffers::Vector *after_padding() const { - return GetPointer *>(VT_AFTER_PADDING); - } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BEFORE_PADDING) && - verifier.Verify(before_padding()) && - VerifyOffset(verifier, VT_AFTER_PADDING) && - verifier.Verify(after_padding()) && verifier.EndTable(); + return VerifyTableStart(verifier) && verifier.EndTable(); } PadOptionsT *UnPack( const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2664,14 +2678,6 @@ struct PadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct PadOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_before_padding( - flatbuffers::Offset> before_padding) { - fbb_.AddOffset(PadOptions::VT_BEFORE_PADDING, before_padding); - } - void add_after_padding( - flatbuffers::Offset> after_padding) { - fbb_.AddOffset(PadOptions::VT_AFTER_PADDING, after_padding); - } explicit PadOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2685,24 +2691,11 @@ struct PadOptionsBuilder { }; inline flatbuffers::Offset CreatePadOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> before_padding = 0, - flatbuffers::Offset> after_padding = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { PadOptionsBuilder builder_(_fbb); - builder_.add_after_padding(after_padding); - builder_.add_before_padding(before_padding); return builder_.Finish(); } -inline flatbuffers::Offset CreatePadOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *before_padding = nullptr, - const std::vector *after_padding = nullptr) { - return tflite::CreatePadOptions( - _fbb, before_padding ? _fbb.CreateVector(*before_padding) : 0, - after_padding ? _fbb.CreateVector(*after_padding) : 0); -} - flatbuffers::Offset CreatePadOptions( flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -3529,6 +3522,111 @@ flatbuffers::Offset CreateSqueezeOptions( flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct StridedSliceOptionsT : public flatbuffers::NativeTable { + typedef StridedSliceOptions TableType; + int32_t begin_mask; + int32_t end_mask; + int32_t ellipsis_mask; + int32_t new_axis_mask; + int32_t shrink_axis_mask; + StridedSliceOptionsT() + : begin_mask(0), + end_mask(0), + ellipsis_mask(0), + new_axis_mask(0), + shrink_axis_mask(0) {} +}; + +struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef StridedSliceOptionsT NativeTableType; + enum { + VT_BEGIN_MASK = 4, + VT_END_MASK = 6, + VT_ELLIPSIS_MASK = 8, + VT_NEW_AXIS_MASK = 10, + VT_SHRINK_AXIS_MASK = 12 + }; + int32_t begin_mask() const { return GetField(VT_BEGIN_MASK, 0); } + int32_t end_mask() const { return GetField(VT_END_MASK, 0); } + int32_t ellipsis_mask() const { + return GetField(VT_ELLIPSIS_MASK, 0); + } + int32_t new_axis_mask() const { + return GetField(VT_NEW_AXIS_MASK, 0); + } + int32_t shrink_axis_mask() const { + return GetField(VT_SHRINK_AXIS_MASK, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BEGIN_MASK) && + VerifyField(verifier, VT_END_MASK) && + VerifyField(verifier, VT_ELLIPSIS_MASK) && + VerifyField(verifier, VT_NEW_AXIS_MASK) && + VerifyField(verifier, VT_SHRINK_AXIS_MASK) && + verifier.EndTable(); + } + StridedSliceOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + StridedSliceOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StridedSliceOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_begin_mask(int32_t begin_mask) { + fbb_.AddElement(StridedSliceOptions::VT_BEGIN_MASK, begin_mask, 0); + } + void add_end_mask(int32_t end_mask) { + fbb_.AddElement(StridedSliceOptions::VT_END_MASK, end_mask, 0); + } + void add_ellipsis_mask(int32_t ellipsis_mask) { + fbb_.AddElement(StridedSliceOptions::VT_ELLIPSIS_MASK, + ellipsis_mask, 0); + } + void add_new_axis_mask(int32_t new_axis_mask) { + fbb_.AddElement(StridedSliceOptions::VT_NEW_AXIS_MASK, + new_axis_mask, 0); + } + void add_shrink_axis_mask(int32_t shrink_axis_mask) { + fbb_.AddElement(StridedSliceOptions::VT_SHRINK_AXIS_MASK, + shrink_axis_mask, 0); + } + explicit StridedSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + StridedSliceOptionsBuilder &operator=(const StridedSliceOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateStridedSliceOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0, + int32_t end_mask = 0, int32_t ellipsis_mask = 0, int32_t new_axis_mask = 0, + int32_t shrink_axis_mask = 0) { + StridedSliceOptionsBuilder builder_(_fbb); + builder_.add_shrink_axis_mask(shrink_axis_mask); + builder_.add_new_axis_mask(new_axis_mask); + builder_.add_ellipsis_mask(ellipsis_mask); + builder_.add_end_mask(end_mask); + builder_.add_begin_mask(begin_mask); + return builder_.Finish(); +} + +flatbuffers::Offset CreateStridedSliceOptions( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -3813,6 +3911,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { ? static_cast(builtin_options()) : nullptr; } + const StridedSliceOptions *builtin_options_as_StridedSliceOptions() const { + return builtin_options_type() == BuiltinOptions_StridedSliceOptions + ? static_cast(builtin_options()) + : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -4020,6 +4123,12 @@ Operator::builtin_options_as() const { return builtin_options_as_SequenceRNNOptions(); } +template <> +inline const StridedSliceOptions * +Operator::builtin_options_as() const { + return builtin_options_as_StridedSliceOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -4959,11 +5068,11 @@ inline void SequenceRNNOptions::UnPackTo( { auto _e = time_major(); _o->time_major = _e; - } + }; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; - } + }; } inline flatbuffers::Offset SequenceRNNOptions::Pack( @@ -5429,24 +5538,6 @@ inline void PadOptions::UnPackTo( PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = before_padding(); - if (_e) { - _o->before_padding.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->before_padding[_i] = _e->Get(_i); - } - } - }; - { - auto _e = after_padding(); - if (_e) { - _o->after_padding.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->after_padding[_i] = _e->Get(_i); - } - } - }; } inline flatbuffers::Offset PadOptions::Pack( @@ -5466,11 +5557,7 @@ inline flatbuffers::Offset CreatePadOptions( const flatbuffers::rehasher_function_t *__rehasher; } _va = {&_fbb, _o, _rehasher}; (void)_va; - auto _before_padding = - _o->before_padding.size() ? _fbb.CreateVector(_o->before_padding) : 0; - auto _after_padding = - _o->after_padding.size() ? _fbb.CreateVector(_o->after_padding) : 0; - return tflite::CreatePadOptions(_fbb, _before_padding, _after_padding); + return tflite::CreatePadOptions(_fbb); } inline ReshapeOptionsT *ReshapeOptions::UnPack( @@ -6037,6 +6124,67 @@ inline flatbuffers::Offset CreateSqueezeOptions( return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims); } +inline StridedSliceOptionsT *StridedSliceOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new StridedSliceOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void StridedSliceOptions::UnPackTo( + StridedSliceOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = begin_mask(); + _o->begin_mask = _e; + }; + { + auto _e = end_mask(); + _o->end_mask = _e; + }; + { + auto _e = ellipsis_mask(); + _o->ellipsis_mask = _e; + }; + { + auto _e = new_axis_mask(); + _o->new_axis_mask = _e; + }; + { + auto _e = shrink_axis_mask(); + _o->shrink_axis_mask = _e; + }; +} + +inline flatbuffers::Offset StridedSliceOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateStridedSliceOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateStridedSliceOptions( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const StridedSliceOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _begin_mask = _o->begin_mask; + auto _end_mask = _o->end_mask; + auto _ellipsis_mask = _o->ellipsis_mask; + auto _new_axis_mask = _o->new_axis_mask; + auto _shrink_axis_mask = _o->shrink_axis_mask; + return tflite::CreateStridedSliceOptions(_fbb, _begin_mask, _end_mask, + _ellipsis_mask, _new_axis_mask, + _shrink_axis_mask); +} + inline OperatorCodeT *OperatorCode::UnPack( const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); @@ -6549,6 +6697,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } @@ -6697,6 +6849,10 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -6832,6 +6988,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast(value); return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast(value); + return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -6982,6 +7142,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) *reinterpret_cast(u.value)); break; } + case BuiltinOptions_StridedSliceOptions: { + value = new StridedSliceOptionsT( + *reinterpret_cast(u.value)); + break; + } default: break; } @@ -7144,6 +7309,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index 07a38c42436655d307c89a987ebba4db38eba442..0c5e00a1f2e6a3303556ec54d8e50e8398644bf5 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ +#define TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ #include #include @@ -85,4 +85,4 @@ class SimpleMemoryArena { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ +#endif // TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h index 8ae05bf7f59a5f9d619ae18a342f8819e19d9888..4c5d8578ac1cd59d757df895fc67394b837b4fae 100644 --- a/tensorflow/contrib/lite/string_util.h +++ b/tensorflow/contrib/lite/string_util.h @@ -37,8 +37,8 @@ limitations under the License. // # described above. // buf.WriteToTensor(tensor) -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ #include @@ -88,4 +88,4 @@ int GetStringCount(const TfLiteTensor* tensor); StringRef GetString(const TfLiteTensor* tensor, int string_index); } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 933da11353a04d4b1538c9b8d777365a875e62fc..50e8ca75f8efd600d4773b83cd2c8de11c9d13ca 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -46,6 +46,7 @@ gen_zipped_test_files( "space_to_batch_nd.zip", "space_to_depth.zip", "squeeze.zip", + "strided_slice.zip", "sub.zip", "transpose.zip", ], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 6c3d31fc9a278e14864c3de12be9e8d0f835c522..a639351657835a1e7d17466e70277e8bf40bc0f9 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -853,34 +853,55 @@ def make_fused_batch_norm_tests(zip_path): def make_conv_tests(zip_path): """Make a set of tests to do convolution.""" - test_parameters = [{ - "input_shape": [[1, 3, 4, 3]], - "filter_shape": [[1, 1, 3, 2]], - "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], - "padding": ["SAME", "VALID"], - "data_format": ["NHWC"], # TODO(aselle): NCHW would be good - }, { - "input_shape": [[2, 14, 14, 2]], - "filter_shape": [[6, 6, 2, 2]], - "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], - "padding": ["SAME", "VALID"], - "data_format": ["NHWC"], # TODO(aselle): NCHW would be good - }] + test_parameters = [ + { + "input_shape": [[1, 3, 4, 3]], + "filter_shape": [[1, 1, 3, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + "constant_filter": [True, False], + }, + { + "input_shape": [[2, 14, 14, 2]], + "filter_shape": [[6, 6, 2, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + "constant_filter": [True, False], + } + ] def build_graph(parameters): + """Build a conv graph given `parameters`.""" input_tensor = tf.placeholder( dtype=tf.float32, name="input", shape=parameters["input_shape"]) - filter_values = create_tensor_data(np.float32, parameters["filter_shape"]) - out = tf.nn.conv2d(input_tensor, filter_values, - strides=parameters["strides"], - padding=parameters["padding"], - data_format=parameters["data_format"]) - return [input_tensor], [out] + + # Get filter input either as a placeholder or constants. Also get a list of + # the input tensors that are represented as placeholders. + if parameters["constant_filter"]: + filter_input = create_tensor_data(np.float32, parameters["filter_shape"]) + input_tensors = [input_tensor] + else: + filter_input = tf.placeholder( + dtype=tf.float32, name="filter", shape=parameters["filter_shape"]) + input_tensors = [input_tensor, filter_input] + + out = tf.nn.conv2d( + input_tensor, + filter_input, + strides=parameters["strides"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(np.float32, parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + # Build list of input values either containing 1 tensor (input) or 2 tensors + # (input, filter) based on whether filter is constant or variable input. + values = [create_tensor_data(np.float32, parameters["input_shape"])] + if not parameters["constant_filter"]: + values.append(create_tensor_data(np.float32, parameters["filter_shape"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -889,45 +910,70 @@ def make_depthwiseconv_tests(zip_path): """Make a set of tests to do convolution.""" # Tensorflow only supports equal strides - test_parameters = [{ - "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]], - "filter_size": [[1, 1], [1, 2], [3, 3]], - "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], - "channel_multiplier": [1, 2], - "rate": [[1, 1]], - "padding": ["SAME", "VALID"], - "data_format": ["NHWC"], - }, { - "input_shape": [[1, 3, 4, 3]], - "filter_size": [[1, 1]], - "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1] - "channel_multiplier": [2], - "rate": [[2, 2]], # Only [1, 1] is supported - "padding": ["SAME"], - "data_format": ["NHWC"], - }] + test_parameters = [ + { + "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]], + "filter_size": [[1, 1], [1, 2], [3, 3]], + "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], + "channel_multiplier": [1, 2], + "rate": [[1, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], + "constant_filter": [True, False], + }, + { + "input_shape": [[1, 3, 4, 3]], + "filter_size": [[1, 1]], + "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1] + "channel_multiplier": [2], + "rate": [[2, 2]], # Only [1, 1] is supported + "padding": ["SAME"], + "data_format": ["NHWC"], + "constant_filter": [True, False], + } + ] - def build_graph(parameters): - """Build a depthwise conv graph given `parameters`.""" + def get_tensor_shapes(parameters): input_shape = parameters["input_shape"] filter_size = parameters["filter_size"] + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"] + ] + return [input_shape, filter_shape] + + def build_graph(parameters): + """Build a depthwise conv graph given `parameters`.""" + input_shape, filter_shape = get_tensor_shapes(parameters) input_tensor = tf.placeholder( dtype=tf.float32, name="input", shape=input_shape) - filter_shape = filter_size + [ - input_shape[3], parameters["channel_multiplier"]] - filter_values = create_tensor_data(np.float32, filter_shape) + + # Get filter input either as a placeholder or constants. Also get a list of + # the input tensors that are represented as placeholders. + if parameters["constant_filter"]: + filter_input = create_tensor_data(np.float32, filter_shape) + input_tensors = [input_tensor] + else: + filter_input = tf.placeholder( + dtype=tf.float32, name="filter", shape=filter_shape) + input_tensors = [input_tensor, filter_input] + out = tf.nn.depthwise_conv2d( - input_tensor, filter_values, + input_tensor, + filter_input, strides=parameters["strides"], rate=parameters["rate"], padding=parameters["padding"], data_format=parameters["data_format"]) - return [input_tensor], [out] + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(np.float32, parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + # Build list of input values either containing 1 tensor (input) or 2 tensors + # (input, filter) based on whether filter is constant or variable input. + input_shape, filter_shape = get_tensor_shapes(parameters) + values = [create_tensor_data(np.float32, input_shape)] + if not parameters["constant_filter"]: + values.append(create_tensor_data(np.float32, filter_shape)) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -978,32 +1024,49 @@ def make_fully_connected_tests(zip_path): "shape2": [[3, 3]], "transpose_a": [True, False], "transpose_b": [True, False], + "constant_filter": [True, False], }, { "shape1": [[4, 4], [1, 4], [4]], "shape2": [[4, 4], [4, 1], [4]], "transpose_a": [False], "transpose_b": [False], + "constant_filter": [True, False], }, { "shape1": [[40, 37]], "shape2": [[37, 40]], "transpose_a": [False], "transpose_b": [False], - + "constant_filter": [True, False], }] def build_graph(parameters): + """Build a matmul graph given `parameters`.""" input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1", shape=parameters["shape1"]) - input_tensor2 = create_tensor_data(np.float32, parameters["shape2"]) + + # Get input_tensor2 either as a placeholder or constants. Also get a list of + # the input tensors that are represented as placeholders. + if parameters["constant_filter"]: + input_tensor2 = create_tensor_data(np.float32, parameters["shape2"]) + input_tensors = [input_tensor1] + else: + input_tensor2 = tf.placeholder( + dtype=tf.float32, name="input2", shape=parameters["shape2"]) + input_tensors = [input_tensor1, input_tensor2] + out = tf.matmul(input_tensor1, input_tensor2, transpose_a=parameters["transpose_a"], transpose_b=parameters["transpose_b"]) - return [input_tensor1], [out] + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"]) - return [input_values1], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values1]))) + # Build list of input values either containing 1 tensor (input_values1) or 2 + # tensors (input_values1, input_values2) based on whether the second input + # is a constant or variable input. + values = [create_tensor_data(np.float32, shape=parameters["shape1"])] + if not parameters["constant_filter"]: + values.append(create_tensor_data(np.float32, parameters["shape2"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1078,28 +1141,43 @@ def make_pad_tests(zip_path): "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]], "paddings": [[[0, 0], [0, 1], [2, 3], [0, 0]], [[0, 1], [0, 0], [0, 0], [2, 3]]], + "constant_paddings": [True, False], }, # Non-4D use case. { "dtype": [tf.int32, tf.int64, tf.float32], "input_shape": [[1, 2], [0, 1, 2]], "paddings": [[[0, 1], [2, 3]]], + "constant_paddings": [True, False], }, ] def build_graph(parameters): + """Build a pad graph given `parameters`.""" input_tensor = tf.placeholder( dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.pad(input_tensor, paddings=parameters["paddings"]) - return [input_tensor], [out] + + # Get paddings as either a placeholder or constants. + if parameters["constant_paddings"]: + paddings = parameters["paddings"] + input_tensors = [input_tensor] + else: + shape = [len(parameters["paddings"]), 2] + paddings = tf.placeholder(dtype=tf.int32, name="padding", shape=shape) + input_tensors = [input_tensor, paddings] + + out = tf.pad(input_tensor, paddings=paddings) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_paddings"]: + values.append(np.array(parameters["paddings"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1361,6 +1439,10 @@ def make_squeeze_tests(zip_path): "dtype": [tf.int32, tf.float32, tf.int64], "input_shape": [[1]], "axis": [None, [], [0], [-1]], + }, { + "dtype": [tf.int32, tf.float32, tf.int64], + "input_shape": [[1, 1, 1, 1, 1]], + "axis": [None, [], [0], [3, 0], [-2, 0, 3, 2]], }] def build_graph(parameters): @@ -1380,6 +1462,97 @@ def make_squeeze_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_strided_slice_tests(zip_path): + """Make a set of tests to do strided_slice.""" + + # TODO(soroosh): add test/support for uint8. + test_parameters = [ + # 4-D + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], + "end": [[8, 2, 2, 3], [12, 2, 2, 5]], + "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]], + "begin_mask": [None, 0, 1, 2, 8], + "end_mask": [None, 0, 1, 2, 8], + }, + # 2-D + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, 0], [1, 0]], + "end": [[2, 3], [2, 2]], + "strides": [None, [1, 1], [2, 2]], + "begin_mask": [None, 0, 1, 2], + "end_mask": [None, 0, 1, 2], + }, + # Negative strides + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, -1]], + "end": [[2, -3]], + "strides": [[1, -1]], + "begin_mask": [None, 0, 1, 2], + "end_mask": [None, 0, 1, 2], + }, + ] + + def build_graph(parameters): + """Build graph for stride_slice test.""" + input_tensor = tf.placeholder( + dtype=parameters["dtype"], + name="input", + shape=parameters["input_shape"]) + begin = tf.placeholder( + dtype=parameters["index_type"], + name="begin", + shape=[len(parameters["input_shape"])]) + end = tf.placeholder( + dtype=parameters["index_type"], + name="end", + shape=[len(parameters["input_shape"])]) + strides = ( + tf.placeholder( + dtype=parameters["index_type"], + name="strides", + shape=[len(parameters["input_shape"])]) + if parameters["strides"] is not None else None) + tensors = [input_tensor, begin, end] + if strides is not None: + tensors.append(strides) + out = tf.strided_slice( + input_tensor, + begin, + end, + strides, + begin_mask=parameters["begin_mask"], + end_mask=parameters["end_mask"]) + return tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build inputs for stride_slice test.""" + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + index_type = _TF_TYPE_INFO[parameters["index_type"]][0] + begin_values = np.array(parameters["begin"]).astype(index_type) + end_values = np.array(parameters["end"]).astype(index_type) + stride_values = ( + np.array(parameters["strides"]).astype(index_type) + if parameters["strides"] is not None else None) + values = [input_values, begin_values, end_values] + if stride_values is not None: + values.append(stride_values) + + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_l2_pool(input_tensor, ksize, strides, padding, data_format): """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" return tf.sqrt(tf.nn.avg_pool( @@ -1438,6 +1611,7 @@ def main(unused_args): "transpose.zip": make_transpose_tests, "mean.zip": make_mean_tests, "squeeze.zip": make_squeeze_tests, + "strided_slice.zip": make_strided_slice_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index c8a6e07abd02633f90ea768ad6f65d2a7d9d716a..41652a07d21fbf022cb66a4022706cfee02d2c09 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -48,47 +48,51 @@ tensorflow::Env* env = tensorflow::Env::Default(); // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { // Add doesn't support broadcasting. - {R"(adda.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, - {R"(mula.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, - {R"(diva.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, - {R"(suba.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + {R"(^\/adda.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + {R"(^\/mula.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + {R"(^\/diva.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + {R"(^\/suba.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, // Add only supports float32. (and "constant" tests use Add) - {R"(adda.*int32)", "68808744"}, - {R"(constant.*int32)", "68808744"}, - {R"(mul.*int32)", "68808744"}, - {R"(div.*int32)", "68808744"}, - {R"(sub.*int32)", "68808744"}, + {R"(^\/adda.*int32)", "68808744"}, + {R"(^\/constant.*int32)", "68808744"}, + {R"(^\/mul.*int32)", "68808744"}, + {R"(^\/div.*int32)", "68808744"}, + {R"(^\/sub.*int32)", "68808744"}, // Pad only supports 4D tensors. - {R"(paddtype=.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", + {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", "70527055"}, // L2Norm only supports tensors with 4D or fewer. - {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, + {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, // SpaceToBatch only supports 4D tensors. - {R"(space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, + {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. - {R"(l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", + "67963812"}, + {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, // ResizeBilinear looks completely incompatible with Tensorflow - {R"(resize_bilinear)", "67964336"}, + {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, + {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"}, + {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"}, + {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"}, // Transpose only supports 1D-4D input tensors. - {R"(transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"}, + {R"(^\/transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -263,6 +267,7 @@ INSTANTIATE_TESTS(div) INSTANTIATE_TESTS(transpose) INSTANTIATE_TESTS(mean) INSTANTIATE_TESTS(squeeze) +INSTANTIATE_TESTS(strided_slice) } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h index 78ef7e2cbe1c323753ac36f1be06a089e650aa37..e2bc4082141f0601c141a193fbea75f8f759146a 100644 --- a/tensorflow/contrib/lite/testing/message.h +++ b/tensorflow/contrib/lite/testing/message.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ #include #include @@ -79,4 +79,4 @@ class Message { } // namespace testing } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc index 7c371f2bd445e10bc6d4b20793582c34300316b3..0caef0fe2201a668b2235a98304eb353072a3c2f 100644 --- a/tensorflow/contrib/lite/testing/parse_testdata.cc +++ b/tensorflow/contrib/lite/testing/parse_testdata.cc @@ -18,6 +18,7 @@ limitations under the License. // ASCII file. #include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include #include #include #include @@ -218,8 +219,8 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, int32_t computed = data[idx]; int32_t reference = example.outputs[0].flat_data[idx]; if (std::abs(computed - reference) > 0) { - fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n", - i, idx, data[idx], example.outputs[0].flat_data[idx]); + fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n", + i, idx, computed, reference); return kTfLiteError; } } @@ -231,8 +232,9 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, int64_t reference = example.outputs[0].flat_data[idx]; if (std::abs(computed - reference) > 0) { fprintf(stderr, - "output[%zu][%zu] did not match %ld vs reference %f\n", i, - idx, data[idx], example.outputs[0].flat_data[idx]); + "output[%zu][%zu] did not match %" PRId64 + " vs reference %" PRId64 "\n", + i, idx, computed, reference); return kTfLiteError; } } diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h index 90839fe24550b6c4a0a3a3f4115c479a71580bb0..7ebf362eb99c5f4cf6ea3654cf71e13ff1de99b3 100644 --- a/tensorflow/contrib/lite/testing/parse_testdata.h +++ b/tensorflow/contrib/lite/testing/parse_testdata.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ +#define TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ #include #include "tensorflow/contrib/lite/interpreter.h" @@ -71,4 +71,4 @@ bool ParseAndRunTests(std::istream* input, TestRunner* test_runner); } // namespace testing } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ +#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h index cfc1e929e9e66a6641fc3a9c47cbe511f692b748..428cfda4f216f0ee6409a32c43a4cf91ecc11922 100644 --- a/tensorflow/contrib/lite/testing/split.h +++ b/tensorflow/contrib/lite/testing/split.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ #include #include @@ -83,4 +83,4 @@ inline std::vector Split(const string& s, const string& delimiter) { } // namespace testing } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index f4b26949b57e0702ac5554afd766a6072af268a4..60eaafa474a01887bee12b031b1f59cc5c91f173 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ #include #include @@ -121,4 +121,4 @@ class TestRunner { } // namespace testing } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 4440d4285e948c3d1622c8de5c47ff3729c5847f..25689a9fb42c06fa3f8f2f92064cf59e8c331637 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ #include @@ -59,4 +59,4 @@ class TfLiteDriver : public TestRunner { } // namespace testing } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h index daccf0e84a450a0ffdf04a1eb8ff319878cfc808..7ed8eb96b7a10eecd915fe426ab3abf0e7a46ca4 100644 --- a/tensorflow/contrib/lite/testing/tokenize.h +++ b/tensorflow/contrib/lite/testing/tokenize.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ #include #include @@ -39,4 +39,4 @@ void Tokenize(std::istream* input, TokenProcessor* processor); } // namespace testing } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 4d4304f022187027950f58050ececae73dedffb6..6d20aec141c7c3a3e48af290edb169c6fd7254cf 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ namespace tflite { @@ -25,4 +25,4 @@ inline void LogToStderr() { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ diff --git a/tensorflow/contrib/lite/tflite_static.bp b/tensorflow/contrib/lite/tflite_static.bp index 771771fd5f6e013bdf5e863e9d216f9b50e86d97..3db884de660475e995040da40b8cb3ee67de0fba 100644 --- a/tensorflow/contrib/lite/tflite_static.bp +++ b/tensorflow/contrib/lite/tflite_static.bp @@ -58,9 +58,11 @@ cc_library_static { "kernels/space_to_batch_nd.cc", "kernels/space_to_depth.cc", "kernels/squeeze.cc", + "kernels/strided_slice.cc", "kernels/sub.cc", "kernels/svdf.cc", "kernels/transpose.cc", + "kernels/unidirectional_sequence_lstm.cc", "kernels/unidirectional_sequence_rnn.cc", "kernels/internal/tensor_utils.cc", "kernels/internal/quantization_util.cc", @@ -83,6 +85,7 @@ cc_library_static { "-Werror", "-Wextra", "-Wno-array-bounds", + "-Wno-extern-c-compat", "-Wno-invalid-partial-specialization", "-Wno-mismatched-tags", "-Wno-missing-field-initializers", diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 967e304742fb27ba05591a3b1614de14cd9f5262..041e2487903c63572a7acda17f2f3ebc701be0c7 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -171,6 +171,8 @@ cc_library( srcs = [ "graph_transformations/convert_expanddims_to_reshape.cc", "graph_transformations/convert_pure_conv_to_depthwise.cc", + "graph_transformations/convert_reorder_axes.cc", + "graph_transformations/convert_trivial_addn_to_add.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc index d4da8f5dfe13a38e8b6886656c5c7e0c8fbb1316..5961d30bf5403df7fa6228e05124479d118dd279 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -148,7 +148,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name, if (!IsAllocatableTransientArray(model, array_name)) { return 0; } - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); CHECK(array->has_shape()) << "Array '" << array_name << "' doesn't have a shape"; if (array->data_type == ArrayDataType::kNone) { @@ -185,7 +185,7 @@ void AllocateTransientArray(const Model& model, const string& array_name, } const std::size_t size = TransientArraySize(model, array_name, transient_data_alignment); - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); CHECK(!array->alloc); allocator->Allocate(size, &array->GetOrCreateAlloc()); } @@ -197,7 +197,7 @@ void DeallocateTransientArray(const Model& model, const string& array_name, if (!IsAllocatableTransientArray(model, array_name)) { return; } - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); CHECK(!!array->alloc); allocator->Deallocate(*array->alloc); } @@ -231,7 +231,7 @@ void AllocateTransientArrays(Model* model, // Construct a sorted map of array names, so that other layout engines can // match exactly. std::map ordered_arrays_map; - for (const auto& pair : model->arrays) { + for (const auto& pair : model->GetArrayMap()) { ordered_arrays_map[pair.first] = pair.second.get(); } diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h index 12d0d0498f5224962f2775d4e3cb7d8e360cbe46..59d8ada1e9bb985f2eaa7ff6d29bc4f1b054a070 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ #include "tensorflow/contrib/lite/toco/model.h" @@ -41,4 +41,4 @@ void AllocateTransientArrays(Model* model, } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index eb2d7ba916e49cd5ec838eb945d478f008f149ca..8004a1a37ae48468e9bf22785ec02f8de54bf236 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -15,8 +15,8 @@ limitations under the License. // This abstracts command line arguments in toco. // Arg is a parseable type that can register a default value, be able to // parse itself, and keep track of whether it was specified. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ #include #include @@ -147,12 +147,12 @@ class Arg final { if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; const std::vector inner_fields_vector = - strings::Split(outer_member, ','); + absl::StrSplit(outer_member, ','); std::unordered_map element; for (const string& member_field : inner_fields_vector) { std::vector outer_member_key_value = - strings::Split(member_field, ':'); + absl::StrSplit(member_field, ':'); if (outer_member_key_value.size() != 2) return false; string& key = outer_member_key_value[0]; string& value = outer_member_key_value[1]; @@ -232,4 +232,4 @@ struct ParsedTocoFlags { }; } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 39809216c77bdadfd44aafbddc8e0979fde66a49..c726eb6d8678e2703f5acba8b3d8d740186939f5 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -278,8 +278,8 @@ std::vector OperatorsToDump(const Model& model) { if (last_specified) { // Return only the part of the graph between graphviz_first_array // and graphviz_last_array. - CHECK(model.arrays.count(dump_options.graphviz_first_array)); - CHECK(model.arrays.count(dump_options.graphviz_last_array)); + CHECK(model.HasArray(dump_options.graphviz_first_array)); + CHECK(model.HasArray(dump_options.graphviz_last_array)); std::unordered_set arrays_already_produced; std::vector arrays_to_produce; arrays_to_produce.push_back(dump_options.graphviz_last_array); @@ -336,7 +336,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { op_properties.color.TextColorString().c_str()); // Add nodes and edges for all inputs of the operator. for (const auto& input : op.inputs) { - if (model.arrays.count(input) == 0) { + if (!model.HasArray(input)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } @@ -352,7 +352,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } // Add nodes and edges for all outputs of the operator. for (const auto& output : op.outputs) { - if (model.arrays.count(output) == 0) { + if (!model.HasArray(output)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.h b/tensorflow/contrib/lite/toco/dump_graphviz.h index 0fb28e3de844b123a60e36bc23c7d2add8189962..ea5a4031c39580be00130a2fd3a89c61da2acf01 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.h +++ b/tensorflow/contrib/lite/toco/dump_graphviz.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ #include @@ -25,4 +25,4 @@ void DumpGraphviz(const Model& model, string* output_file_contents); } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 90fa442746cdee975b0103ce60817a95f9b31086..529df3cd2e56f1888f3d431ddcd7dc7051a98355 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -156,8 +156,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name, const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); - CHECK(model.arrays.count(name)); - const auto& input_array = *model.arrays.at(name); + CHECK(model.HasArray(name)); + const auto& input_array = model.GetArray(name); const auto& input_shape = input_array.shape(); CHECK(input_array.buffer); CHECK(input_array.buffer->type == ArrayDataType::kFloat); @@ -177,8 +177,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name, const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); - CHECK(model.arrays.count(name)); - const auto& input_array = *model.arrays.at(name); + CHECK(model.HasArray(name)); + const auto& input_array = model.GetArray(name); const auto& input_shape = input_array.shape(); CHECK(input_array.buffer); CHECK(input_array.buffer->type == ArrayDataType::kFloat); @@ -193,8 +193,8 @@ void ConvertIntTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - CHECK(model.arrays.count(name)); - const auto& array = *model.arrays.at(name); + CHECK(model.HasArray(name)); + const auto& array = model.GetArray(name); auto* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); @@ -324,7 +324,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, biasadd_op->add_input(conv_output); biasadd_op->add_input(src_op.inputs[2]); (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); - CHECK(model.arrays.count(src_op.inputs[2])); + CHECK(model.HasArray(src_op.inputs[2])); const string& bias_array_name = WalkUpToConstantArray(model, src_op.inputs[2]); const auto& bias_array = model.GetArray(bias_array_name); @@ -361,7 +361,7 @@ void ConvertDepthwiseConvOperator(const Model& model, // We need to convert that to H x W x InputDepth x Multiplier. // That's only a matter of constructing a Dims object; the actual // array layout is the same. - CHECK(model.arrays.count(src_op.inputs[1])); + CHECK(model.HasArray(src_op.inputs[1])); const string& src_weights_name = WalkUpToConstantArray(model, src_op.inputs[1]); const auto& src_weights_array = model.GetArray(src_weights_name); @@ -404,7 +404,7 @@ void ConvertDepthwiseConvOperator(const Model& model, biasadd_op->add_input(conv_output); biasadd_op->add_input(src_op.inputs[2]); (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); - CHECK(model.arrays.count(src_op.inputs[2])); + CHECK(model.HasArray(src_op.inputs[2])); const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]); const auto& bias_array = model.GetArray(bias_name); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. @@ -469,10 +469,10 @@ void ConvertFullyConnectedOperator(const Model& model, (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); (*matmul_op->mutable_attr())["transpose_a"].set_b(false); (*matmul_op->mutable_attr())["transpose_b"].set_b(false); - CHECK(model.arrays.count(src_op.inputs[1])); + CHECK(model.HasArray(src_op.inputs[1])); const string& fc_weights_name = WalkUpToConstantArray(model, src_op.inputs[1]); - const auto& fc_weights_array = *model.arrays.at(fc_weights_name); + const auto& fc_weights_array = model.GetArray(fc_weights_name); const auto& fc_weights_shape = fc_weights_array.shape(); CHECK_EQ(fc_weights_shape.dimensions_count(), 2); CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, @@ -492,8 +492,8 @@ void ConvertFullyConnectedOperator(const Model& model, biasadd_op->add_input(matmul_output); biasadd_op->add_input(src_op.inputs[2]); (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); - CHECK(model.arrays.count(src_op.inputs[2])); - const auto& bias_array = *model.arrays.at(src_op.inputs[2]); + CHECK(model.HasArray(src_op.inputs[2])); + const auto& bias_array = model.GetArray(src_op.inputs[2]); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. Shape bias_shape_1d = bias_array.shape(); UnextendShape(&bias_shape_1d, 1); @@ -519,6 +519,18 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op, (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, + GraphDef* tensorflow_graph) { + auto* add_op = tensorflow_graph->add_node(); + add_op->set_op("AddN"); + add_op->set_name(src_op.outputs[0]); + for (const auto& input : src_op.inputs) { + *add_op->add_input() = input; + } + (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size()); + (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + void ConvertMulOperator(const Model& model, const MulOperator& src_op, GraphDef* tensorflow_graph) { auto* add_op = tensorflow_graph->add_node(); @@ -625,7 +637,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, *reshape_op->add_input() = softmax_size; (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); - const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape(); + const auto& input_shape = model.GetArray(src_op.inputs[0]).shape(); int32 flattened_size = 1; for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { flattened_size *= input_shape.dims(i); @@ -1013,8 +1025,8 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Op names have been chosen to match the tf.slim LSTM naming // as closely as possible. const int axis = - model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]) - ->shape() + model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]) + .shape() .dimensions_count() - 1; // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat @@ -1033,9 +1045,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write weights const string weights_output = base + "weights"; - CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); + CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); const auto& weights_array = - *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); // Convert 4D FullyConnected weights into 2D matrix const auto& weights_shape = weights_array.shape(); CHECK_EQ(weights_shape.dimensions_count(), 2); @@ -1059,9 +1071,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write biases const string biases_output = base + "biases"; - CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT])); + CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT])); const auto& bias_array = - *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]); + model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. Shape bias_shape_1d = bias_array.shape(); UnextendShape(&bias_shape_1d, 1); @@ -1406,6 +1418,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kAdd) { ConvertAddOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kAddN) { + ConvertAddNOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kMul) { ConvertMulOperator(model, static_cast(src_op), tensorflow_graph); @@ -1557,7 +1572,7 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size, (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape(); - const auto& state_array = *model.arrays.at(name); + const auto& state_array = model.GetArray(name); if (state_array.has_shape()) { const auto& state_shape = state_array.shape(); const int kDims = state_shape.dimensions_count(); @@ -1574,7 +1589,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, GraphDef* tensorflow_graph) { for (const auto& input_array : model.flags.input_arrays()) { AddPlaceholder(input_array.name(), - model.arrays.at(input_array.name())->data_type, + model.GetArray(input_array.name()).data_type, tensorflow_graph); } for (const auto& rnn_state : model.flags.rnn_states()) { @@ -1588,7 +1603,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, // by the above operators export. It's important that this comes // after, as some operators need to export arrays that they reference // in a specific way, rather than in the generic way done below. - for (const auto& array_pair : model.arrays) { + for (const auto& array_pair : model.GetArrayMap()) { const string& array_name = array_pair.first; const auto& array = *array_pair.second; if (array.buffer) { diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h index eca97745767387a04bcd2c8deb579928edf2497c..79682153a8fd143c4934095567764b886bd776af 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.h +++ b/tensorflow/contrib/lite/toco/export_tensorflow.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ #include #include "tensorflow/contrib/lite/toco/model.h" @@ -24,4 +24,4 @@ void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents); } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h index 0e999001e0e35fb916b11db199dbf28572685f3d..eb81e90faf20133ed722185928f86ef45ac4f8f6 100644 --- a/tensorflow/contrib/lite/toco/format_port.h +++ b/tensorflow/contrib/lite/toco/format_port.h @@ -16,8 +16,8 @@ limitations under the License. // and util::format::AppendF. Unfortunately, type safety is not as good as a // a full C++ example. // TODO(aselle): When absl adds support for StrFormat, use that instead. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -74,4 +74,4 @@ inline string StringF(const char* fmt, Args&&... args) { } // namespace port } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 4776741ab9273cf3b2ef0c63a6dbfdea5475b057..5e077952235fa1aac1e12403d3d83633a617ccb7 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -229,7 +229,7 @@ additional information about the multiple input arrays: well-formed quantized representation of these graphs. Such graphs should be fixed, but as a temporary work-around, setting this reorder_across_fake_quant flag allows the converter to perform necessary - graph transformaitons on them, at the cost of no longer faithfully matching + graph transformations on them, at the cost of no longer faithfully matching inference and training arithmetic. ### Logging flags diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 3bde9b0169ddfb7fc37657122e2e8eb65ccbdf6d..56f48d47de4e86ece76ceef1d09a25f50957a8dc 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -35,7 +35,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(expand_op->inputs.size(), 2); CHECK_EQ(expand_op->outputs.size(), 1); - const auto& input_array = *model->arrays[expand_op->inputs[0]]; + const auto& input_array = model->GetArray(expand_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return false; @@ -46,7 +46,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { return false; } - const auto& axis_array = *model->arrays[expand_op->inputs[1]]; + const auto& axis_array = model->GetArray(expand_op->inputs[1]); if (!axis_array.has_shape()) { // Yield until input axis array shape has been resolved. return false; @@ -86,7 +86,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { if (IsDiscardableArray(*model, axis_array_name) && CountOpsWithInput(*model, axis_array_name) == 1 && !GetOpWithOutput(*model, axis_array_name)) { - model->arrays.erase(axis_array_name); + model->EraseArray(axis_array_name); } // Replace the operator in the graph. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index bf454c40c7b50d242d8a7e9eb6b7e579fb0da217..d38db85280d7bd935a47cda70227d383a513fbac 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -58,7 +58,7 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { depthwiseconv_op->outputs = {conv_op->outputs[0]}; if (conv_op->outputs.size() > 1) { // delete the im2col array. - model->arrays.erase(conv_op->outputs[1]); + model->EraseArray(conv_op->outputs[1]); } depthwiseconv_op->fused_activation_function = conv_op->fused_activation_function; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d274fc687c8d42d47ddb5beb4f9c6f39b417097 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// Creates a Reshape operator from ReorderAxes operator. +TensorFlowReshapeOperator* CreateReshapeFromReorderAxes( + Model* model, ReorderAxesOperator* reorder_op, const Shape& input_shape) { + auto* reshape_op = new TensorFlowReshapeOperator; + + // Copy inputs and outputs to Reshape. + reshape_op->inputs.push_back(reorder_op->inputs[0]); + reshape_op->outputs = reorder_op->outputs; + + // Create reshape dimensions based on input shape. Conversion from + // ReorderAxes to Reshape requires a 4D input shape. + CHECK_EQ(input_shape.dimensions_count(), 4); + std::vector reshape_dims = {1, input_shape.dims(0), input_shape.dims(1), + input_shape.dims(3) * input_shape.dims(2)}; + + // Create a new input array for Reshape. + string reshape_array_name = + AvailableArrayName(*model, reshape_op->outputs[0]); + reshape_op->inputs.push_back(reshape_array_name); + + Array& reshape_array = model->GetOrCreateArray(reshape_array_name); + *(reshape_array.mutable_shape()->mutable_dims()) = { + 1, static_cast(reshape_dims.size())}; + reshape_array.data_type = ArrayDataType::kInt32; + auto& reshape_buffer = + reshape_array.GetMutableBuffer(); + reshape_buffer.data = reshape_dims; + + return reshape_op; +} + +// Creates a Transpose operator from ReorderAxes operator. +TransposeOperator* CreateTransposeFromReorderAxes( + Model* model, ReorderAxesOperator* reorder_op, const Shape& input_shape, + const AxesOrder& input_axes_order, const AxesOrder& output_axes_order) { + auto* transpose_op = new TransposeOperator; + + // Copy inputs and outputs to Transpose. + transpose_op->inputs.push_back(reorder_op->inputs[0]); + transpose_op->outputs = reorder_op->outputs; + + // Create permutations data based on input and output axes order. + std::vector permutations_data; + GetShuffleShape(input_axes_order, output_axes_order, &permutations_data); + + // Create a new input permutations array for Transpose. + string perm_array_name = AvailableArrayName(*model, transpose_op->outputs[0]); + transpose_op->inputs.push_back(perm_array_name); + + Array& perm_array = model->GetOrCreateArray(perm_array_name); + *(perm_array.mutable_shape()->mutable_dims()) = { + static_cast(permutations_data.size())}; + perm_array.data_type = ArrayDataType::kInt32; + auto& perm_buffer = perm_array.GetMutableBuffer(); + perm_buffer.data = permutations_data; + + return transpose_op; +} + +// Converts ReorderAxes into Transpose and Reshape which are compatible with the +// TFLite interpreter. +bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) { + auto reorder_it = model->operators.begin() + op_index; + if (reorder_it->get()->type != OperatorType::kReorderAxes) return false; + + auto* reorder_op = static_cast(reorder_it->get()); + CHECK_EQ(reorder_op->inputs.size(), 1); + CHECK_EQ(reorder_op->outputs.size(), 1); + + const auto& input_array_name = reorder_op->inputs[0]; + const auto& output_array_name = reorder_op->outputs[0]; + auto& input_array = model->GetArray(input_array_name); + auto& output_array = model->GetArray(output_array_name); + + // Get input array. If kFakeQuant is the input into ReorderAxes, get the input + // array passed into kFakeQuant. kFakeQuant op is dropped when possible. + string constant_input_array_name = input_array_name; + if (!input_array.buffer) { + const auto* op_producing_input = GetOpWithOutput(*model, input_array_name); + if (op_producing_input && + op_producing_input->type == OperatorType::kFakeQuant) { + constant_input_array_name = op_producing_input->inputs[0]; + } + } + + // Yield if input array contains constants or if output array size has not + // been adjusted to reflect the permutations in ReorderAxes. ReorderAxes will + // be merged into a constant array when possible. + if (IsConstantParameterArray(*model, constant_input_array_name)) return false; + if (!output_array.has_shape()) return false; + + const auto input_axes_order = reorder_op->input_axes_order; + const auto output_axes_order = reorder_op->output_axes_order; + const Shape input_shape = input_array.shape(); + + // Creates a Reshape or Transpose operator depending on the conversion. + if (input_axes_order == AxesOrder::kHWIM && + output_axes_order == AxesOrder::k1HWO) { + // Add Reshape operator into the graph. This special case is not just a + // permutation. The input dimensions get merged into 3 dimensions while the + // order of the elements does not change. + auto* reshape_op = + CreateReshapeFromReorderAxes(model, reorder_op, input_shape); + const auto reshape_it = model->operators.emplace(reorder_it, reshape_op); + reorder_it = reshape_it + 1; + } else { + // Add Transpose operator into the graph. + auto* transpose_op = CreateTransposeFromReorderAxes( + model, reorder_op, input_shape, input_axes_order, output_axes_order); + const auto transpose_it = + model->operators.emplace(reorder_it, transpose_op); + reorder_it = transpose_it + 1; + } + + // Remove ReorderAxes operator from the graph. + CHECK_EQ(reorder_it->get(), reorder_op); + model->operators.erase(reorder_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc new file mode 100644 index 0000000000000000000000000000000000000000..dcaaddbf3b5409f0fc3ddaf32e23b1e5eefb6565 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// This pass will convert an AddN operator with only 2 inputs into a regular Add +// operator, to which more optimizations may apply. +bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) { + auto addn_it = model->operators.begin() + op_index; + if (addn_it->get()->type != OperatorType::kAddN) { + return false; + } + AddNOperator* addn_op = static_cast(addn_it->get()); + CHECK_GE(addn_op->inputs.size(), 2); + CHECK_EQ(addn_op->outputs.size(), 1); + + // We only reduce AddN with N=2 to a regular Add. + if (addn_op->inputs.size() != 2) { + return false; + } + + // Copy inputs & outputs to regular Add. + auto* add_op = new AddOperator; + add_op->inputs.push_back(addn_op->inputs[0]); + add_op->inputs.push_back(addn_op->inputs[1]); + add_op->outputs = addn_op->outputs; + + // Replace the AddN operator in the graph. + const auto add_it = model->operators.emplace(addn_it, add_op); + addn_it = add_it + 1; + CHECK_EQ(addn_it->get(), addn_op); + model->operators.erase(addn_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index a234c209240ecb9eeba1d2e416a294be53d221ee..c2b166033c33b777bad88cb712adf8517be1762a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -29,7 +29,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { TransposeOperator* transpose_op = static_cast(transpose_it->get()); - const auto& output_array = *model->arrays[transpose_op->outputs[0]]; + const auto& output_array = model->GetArray(transpose_op->outputs[0]); if (!output_array.has_shape()) { // Yield until PropagateFixedSizes has been run on this op. return false; @@ -70,7 +70,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { // Delete perm array if unused if (IsDiscardableArray(*model, perm_array_name) && CountOpsWithInput(*model, perm_array_name) == 1) { - model->arrays.erase(perm_array_name); + model->EraseArray(perm_array_name); } // Replace the operator in the graph. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 1735b51e5b6ca517bad62bf55f0cc9f0c21ac440..076415ece8c1039caa32e947fe54ab3e101bec9e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -35,7 +35,7 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { // We already have an im2col array return false; } - const auto& weights_array = *model->arrays[conv_op->inputs[1]]; + const auto& weights_array = model->GetArray(conv_op->inputs[1]); if (!weights_array.has_shape()) { // We need to yield until weights dims have been resolved, because // from the weights dims we determine whether an im2col array is diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index b89e3f5310cd7364294ad875cfcdf9c14660366b..498c864bde6d656c8318e981204cb42cb3a4d03f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -53,7 +53,7 @@ std::vector>::iterator FindFirstOpWithInput( } void ClearArrayQuantizationParams(const string& array_name, Model* model) { - auto* array = model->arrays.at(array_name).get(); + auto* array = &model->GetArray(array_name); CHECK(array->quantization_params); for (auto& input_array : *model->flags.mutable_input_arrays()) { if (input_array.name() == array_name) { @@ -77,7 +77,7 @@ void ClearArrayQuantizationParams(const string& array_name, Model* model) { bool DequantizeArray(const string& array_name, GraphTransformation* transformation, Model* model) { - auto* array = model->arrays.at(array_name).get(); + auto* array = &model->GetArray(array_name); if (!array->quantization_params) { return false; } @@ -214,7 +214,9 @@ bool Dequantize::Run(Model* model, std::size_t op_index) { } bool changed = false; for (const string& array : arrays) { - changed |= DequantizeArray(array, this, model); + if (!model->IsOptionalArray(array)) { + changed |= DequantizeArray(array, this, model); + } } return changed; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc index fea360740f4e645e1f00eaed42cbff48f430fe2a..95558ef5ece9a78825daf0203e2f6f6fee6f3cda 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc @@ -45,7 +45,7 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) { // Drop min/max inputs for (int i = 1; i < fakequant_op->inputs.size(); i++) { if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { - model->arrays.erase(fakequant_op->inputs[i]); + model->EraseArray(fakequant_op->inputs[i]); } } fakequant_op->inputs.resize(1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc index a3ed6663bcc80c5fc642a399b1e5c0cf3336973a..f7fd878b7e8b1c834125130ea2a778cecefd3de0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -32,7 +32,7 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) { // Drop the im2col array. CHECK_EQ(conv_op->outputs.size(), 2); - model->arrays.erase(conv_op->outputs[1]); + model->EraseArray(conv_op->outputs[1]); conv_op->outputs.resize(1); AddMessageF("Dropped an im2col array for %s", LogName(*conv_op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc index ad4a6f9b78b06fd738da40c2054c07e8f272ee17..88e59664ec427841df6f20686238feacef6a47e9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -91,7 +91,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { } else { LOG(FATAL) << "Unhandled activation function type"; } - model->arrays.erase(ac_op->inputs[0]); + model->EraseArray(ac_op->inputs[0]); op->outputs[0] = ac_op->outputs[0]; model->operators.erase(ac_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc index 4619d8bbee2e52483a523277f421de5bfa155635..dcbbead517f26a227363989b5af2a4040c98ff57 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -285,13 +285,13 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF("Fusing %s into the following %s", LogName(*binary_op), LogName(*following_op)); - model->arrays.erase(binary_op->outputs[0]); + model->EraseArray(binary_op->outputs[0]); following_op->inputs[0] = binary_op->inputs[index_of_variable_input]; const auto& old_constant_param_name = binary_op->inputs[index_of_constant_input]; CHECK(IsConstantParameterArray(*model, old_constant_param_name)); if (CountOpsWithInput(*model, old_constant_param_name) == 1) { - model->arrays.erase(old_constant_param_name); + model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 8948653ec38f5a5a6e92cfe9e6bafdbf1aa9a962..5b57178b18d2d60e1f301a1a8b257d8057618550 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -309,7 +309,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LOG(FATAL) << "should not get here"; } - model->arrays.erase(preceding_op->outputs[0]); + model->EraseArray(preceding_op->outputs[0]); preceding_op->outputs[0] = binary_op->outputs[0]; preceding_op->fused_activation_function = binary_op->fused_activation_function; @@ -317,7 +317,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { binary_op->inputs[index_of_constant_input]; CHECK(IsConstantParameterArray(*model, old_constant_param_name)); if (CountOpsWithInput(*model, old_constant_param_name) == 1) { - model->arrays.erase(old_constant_param_name); + model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc index f861c4147a04fe31b7236bfa22ed4627f7742d09..6961e23690a5e53643f2b2c52bb62ce395d05c95 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc @@ -31,13 +31,13 @@ namespace { void PrintModelStats(const string& label, const Model& model) { int quantized_arrays = 0; - for (const auto& array : model.arrays) { + for (const auto& array : model.GetArrayMap()) { if (array.second->quantization_params) { quantized_arrays++; } } LOG(INFO) << label << ": " << model.operators.size() << " operators, " - << model.arrays.size() << " arrays (" << quantized_arrays + << model.GetArrayMap().size() << " arrays (" << quantized_arrays << " quantized)"; } @@ -91,14 +91,9 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { } } while (found_new_useful_arrays); // Erase arrays that aren't useful, and that are discardable. - for (auto it = model->arrays.begin(); it != model->arrays.end();) { - if (useful_arrays.count(it->first) || - !IsDiscardableArray(*model, it->first)) { - ++it; - } else { - it = model->arrays.erase(it); - } - } + model->EraseArrays([&](const string& name) { + return (!useful_arrays.count(name) && IsDiscardableArray(*model, name)); + }); // Erase operators that do not produce a useful output array. for (auto it = model->operators.begin(); it != model->operators.end();) { // Only need to test the first output, as we simultaneously added all of @@ -118,8 +113,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { std::vector rnn_states_to_keep; for (const auto& rnn_state : model->flags.rnn_states()) { const bool dangling = - !model->arrays.count(rnn_state.back_edge_source_array()) || - !model->arrays.count(rnn_state.state_array()); + !model->HasArray(rnn_state.back_edge_source_array()) || + !model->HasArray(rnn_state.state_array()); if (dangling) { CHECK(rnn_state.discardable()); } else { @@ -137,6 +132,7 @@ bool GraphTransformationsPass(int increment, Model* model, CHECK(increment == 1 || increment == -1); bool changed = false; if (model->operators.empty()) { + LOG(INFO) << "Model is empty!!!"; return false; } int op_index = increment == 1 ? 0 : model->operators.size() - 1; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 9ec9f92c90fb93962994d084d9354c11ba367e95..e11bebcd4e0f66faf63290e3af0c72c39811cebe 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ #include #include @@ -114,7 +114,9 @@ void RunGraphTransformations(Model* model, const string& message, // List of all graph transformations DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) @@ -192,4 +194,4 @@ class RemoveTrivialReshape : public GraphTransformation { } // end namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index 01b75e37c691d48fabf8832af04543be3f5eb3bc..419a0776a6b987a18df059d3c1d4bf4370cd24d8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -150,19 +150,19 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { // Erase the subgraph that is now replaced by L2Normalization model->operators.erase(FindOperator(model, square_op)); - model->arrays.erase(sum_op->inputs[0]); + model->EraseArray(sum_op->inputs[0]); if (sum_op->inputs.size() > 1) { - model->arrays.erase(sum_op->inputs[1]); + model->EraseArray(sum_op->inputs[1]); } model->operators.erase(FindOperator(model, sum_op)); if (add_op) { - model->arrays.erase(add_op->inputs[0]); - model->arrays.erase(add_op->inputs[1]); + model->EraseArray(add_op->inputs[0]); + model->EraseArray(add_op->inputs[1]); model->operators.erase(FindOperator(model, add_op)); } - model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]); + model->EraseArray(sqrt_or_rsqrt_op->inputs[0]); model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); - model->arrays.erase(div_or_mul_op->inputs[1]); + model->EraseArray(div_or_mul_op->inputs[1]); model->operators.erase(FindOperator(model, div_or_mul_op)); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc index 1865416fc2226d663dfd51a5c0a0e2129caf485c..e4d52476c649de53b3ab663f53ce7a5538dbb5ab 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -92,8 +92,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op)); // Erase intermediate arrays, keeping input to square op. - model->arrays.erase(avpool_op->inputs[0]); - model->arrays.erase(sqrt_op->inputs[0]); + model->EraseArray(avpool_op->inputs[0]); + model->EraseArray(sqrt_op->inputs[0]); // Erase three operators being replaced. model->operators.erase(FindOperator(model, square_op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index cfc77024e7e56038878570c9d3a462715a53ae3f..d36e95060937d6af0789766bcb29ae70cef4569d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -89,12 +89,12 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); // Erase Maximum scalar input & operator - model->arrays.erase(maximum_op->inputs[scalar_input_index]); + model->EraseArray(maximum_op->inputs[scalar_input_index]); model->operators.erase(FindOperator(model, maximum_op)); // Erase Minimum inputs & operator - model->arrays.erase(minimum_op->inputs[0]); - model->arrays.erase(minimum_op->inputs[1]); + model->EraseArray(minimum_op->inputs[0]); + model->EraseArray(minimum_op->inputs[1]); model->operators.erase(FindOperator(model, minimum_op)); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index c6f17cf31967d4b5dfa004b5e76120482e92392d..f0d107232b4517115aa3f64b39b825dbaffb83ce 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -27,7 +27,7 @@ namespace { void SetDataTypeForAllOutputs(Model* model, Operator* op, ArrayDataType data_type) { for (const auto& output : op->outputs) { - model->arrays[output]->data_type = data_type; + model->GetArray(output).data_type = data_type; } } } // namespace @@ -38,7 +38,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // If the data type of some input is unknown, we need to yield. for (const auto& input : op->inputs) { - if (model->arrays[input]->data_type == ArrayDataType::kNone) { + if (!model->IsOptionalArray(input) && + model->GetArray(input).data_type == ArrayDataType::kNone) { return false; } } @@ -46,7 +47,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // end if we changed anything, and return the correct boolean value. std::unordered_map old_output_data_types; for (const auto& output : op->outputs) { - old_output_data_types[output] = model->arrays[output]->data_type; + old_output_data_types[output] = model->GetArray(output).data_type; } // Do the actual output data types propagation. if (op->type == OperatorType::kDequantize || @@ -68,18 +69,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { op->type == OperatorType::kFill) { // These operators produce an output with the same type as their 2nd input CHECK_GE(op->inputs.size(), 2); - const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type; + const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; SetDataTypeForAllOutputs(model, op, data_type); } else if (op->type == OperatorType::kCast) { // Data type of the Cast op is specified. CHECK_EQ(op->outputs.size(), 1); auto* cast_op = static_cast(op); - model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type; + model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; } else if (op->type == OperatorType::kArgMax) { // Data type of the ArgMax op is specified. CHECK_EQ(op->outputs.size(), 1); auto* argmax_op = static_cast(op); - model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type; + model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; } else if (op->type == OperatorType::kRange) { auto* range_op = static_cast(op); // Output type of the Range op can be set via an attribute @@ -90,7 +91,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } else { // Otherwise use the first input CHECK_GE(op->inputs.size(), 1); - data_type = model->arrays[op->inputs[0]]->data_type; + data_type = model->GetArray(op->inputs[0]).data_type; } CHECK_EQ(op->outputs.size(), 1); SetDataTypeForAllOutputs(model, op, data_type); @@ -102,7 +103,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { auto output = op->outputs[i]; auto data_type = unsupported_op->output_data_types[i]; - model->arrays[output]->data_type = data_type; + model->GetArray(output).data_type = data_type; } } else if (op->type == OperatorType::kExpandDims) { // Yield on ExpandDim until it is converted to Reshape @@ -110,12 +111,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } else { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); - const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type; + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; SetDataTypeForAllOutputs(model, op, data_type); } // Return true if any output data type changed, false if none changed. for (const auto& output : op->outputs) { - if (old_output_data_types[output] != model->arrays[output]->data_type) { + if (old_output_data_types[output] != model->GetArray(output).data_type) { return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index a939efb4dbbc6ec0af2e44270d7c028eff882b70..4fb3b6ae7a5fc5bfc2719b978331c67ae799eb54 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -85,7 +85,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape1, int GetOutputDepthFromWeights(const Model& model, const Operator& op) { const string& weights_name = op.inputs[1]; - const auto& weights_shape = model.arrays.at(weights_name)->shape(); + const auto& weights_shape = model.GetArray(weights_name).shape(); if (op.type == OperatorType::kConv || op.type == OperatorType::kFullyConnected) { return weights_shape.dims(0); @@ -98,7 +98,7 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) { bool EnsureBiasVectorShape(Model* model, Operator* op) { const string& weights_name = op->inputs[1]; - const auto& weights_array = *model->arrays[weights_name]; + const auto& weights_array = model->GetArray(weights_name); // Yield until weights shape has been resolved. if (!weights_array.has_shape()) { return false; @@ -107,7 +107,7 @@ bool EnsureBiasVectorShape(Model* model, Operator* op) { if (op->inputs.size() < 3) { return false; } - auto& bias_array = *model->arrays[op->inputs[2]]; + auto& bias_array = model->GetArray(op->inputs[2]); if (bias_array.has_shape()) { return true; } @@ -126,7 +126,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -134,7 +134,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { const auto& input_shape = input_array.shape(); CHECK_EQ(input_shape.dimensions_count(), 4); - const auto& weights_array = *model->arrays[op->inputs[1]]; + const auto& weights_array = model->GetArray(op->inputs[1]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -156,7 +156,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { if (op->outputs.size() == 2) { const auto& output_shape = output_array.shape(); const int input_depth = weights_shape.dims(3); - auto& im2col_array = *model->arrays[op->outputs[1]]; + auto& im2col_array = model->GetArray(op->outputs[1]); im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1), output_shape.dims(2), input_depth * kheight * kwidth}); @@ -168,7 +168,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -176,7 +176,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { const auto& input_shape = input_array.shape(); CHECK_EQ(input_shape.dimensions_count(), 4); - const auto& weights_array = *model->arrays[op->inputs[1]]; + const auto& weights_array = model->GetArray(op->inputs[1]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -209,7 +209,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { } void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -232,7 +232,7 @@ void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) { } void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -258,7 +258,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { void ProcessFillOperator(Model* model, FillOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // We have already run return; @@ -287,7 +287,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -295,7 +295,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { const auto& input_shape = input_array.shape(); CHECK_GE(input_shape.dimensions_count(), 1); - const auto& weights_array = *model->arrays[op->inputs[1]]; + const auto& weights_array = model->GetArray(op->inputs[1]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -315,13 +315,13 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { void ProcessTensorFlowReshapeOperator(Model* model, TensorFlowReshapeOperator* op) { - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // We have already run return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -377,14 +377,14 @@ void ProcessTensorFlowReshapeOperator(Model* model, } void ProcessSimpleOperator(Model* model, Operator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; } const string& output_name = op->outputs[0]; - auto& output_array = *model->arrays[output_name]; + auto& output_array = model->GetArray(output_name); if (output_array.has_shape()) { return; } @@ -394,18 +394,40 @@ void ProcessSimpleOperator(Model* model, Operator* op) { void ProcessSimpleBinaryOperator(Model* model, Operator* op) { CHECK_EQ(op->inputs.size(), 2); - const auto& input0_array = *model->arrays[op->inputs[0]]; - const auto& input1_array = *model->arrays[op->inputs[1]]; + const auto& input0_array = model->GetArray(op->inputs[0]); + const auto& input1_array = model->GetArray(op->inputs[1]); // Yield until input dims have been resolved. if (!input0_array.has_shape() || !input1_array.has_shape()) { return; } const string& output_name = op->outputs[0]; - auto& output_array = *model->arrays[output_name]; + auto& output_array = model->GetArray(output_name); ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(), &output_array); } +void ProcessAddNOperator(Model* model, Operator* op) { + // Yield until all input dims have been resolved. + // + // TODO(myenik): Since AddN does not support broadcasting, maybe we could + // actually use this to improve shape propagation by propagating the shape of + // one input to all other inputs once it is resolved instead of just the + // output, since all inputs must be the same size and shape for a well-formed + // graph. + for (const auto& input : op->inputs) { + const auto& input_array = model->GetArray(input); + if (!input_array.has_shape()) { + return; + } + } + + // AddN does not support broadcasting, all inputs must be the same shape, so + // we just take the first input shape and apply it to the output. + const auto& input0_array = model->GetArray(op->inputs[0]); + auto& output_array = model->GetArray(op->outputs[0]); + output_array.copy_shape(input0_array.shape()); +} + bool KeepDims(const Operator& op) { switch (op.type) { case OperatorType::kTensorFlowMin: @@ -424,11 +446,11 @@ bool KeepDims(const Operator& op) { void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { CHECK_LE(op->inputs.size(), 2); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { return; } @@ -436,7 +458,7 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { const bool keep_dims = KeepDims(*op); if (op->inputs.size() == 2) { // There is a reduction_indices input. - const auto& reduction_array = *model->arrays[op->inputs[1]]; + const auto& reduction_array = model->GetArray(op->inputs[1]); if (!reduction_array.buffer) { return; } @@ -476,11 +498,11 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) { if (op->begin.empty()) return; // Yield until input dims have been resolved. - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) return; const Shape& input_shape = input_array.shape(); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) return; CHECK_EQ(input_shape.dims().size(), op->size.size()); @@ -500,7 +522,7 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) { void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -515,20 +537,20 @@ void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) { void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { // Yield until input dims have been resolved. for (const auto& input_name : op->inputs) { - auto& input_array = *model->arrays[input_name]; + auto& input_array = model->GetArray(input_name); if (!input_array.has_shape()) { return; } } auto& output_array = model->GetArray(op->outputs[0]); // Use 0 input as basis for output dimensions. - const auto& first_input_array = *model->arrays[op->inputs[0]]; + const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; for (const auto& input_name : op->inputs) { - auto& input_array = *model->arrays[input_name]; + auto& input_array = model->GetArray(input_name); CHECK(input_array.has_shape()); if (input_array.shape().dimensions_count() == 0) { continue; @@ -548,16 +570,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { void ProcessRangeOperator(Model* model, RangeOperator* op) { CHECK_EQ(op->inputs.size(), 3); - const auto& start_array = *model->arrays[op->inputs[0]]; + const auto& start_array = model->GetArray(op->inputs[0]); if (!start_array.has_shape()) { // Yield until input dims have been resolved. return; } - const auto& limit_array = *model->arrays[op->inputs[1]]; + const auto& limit_array = model->GetArray(op->inputs[1]); if (!limit_array.has_shape()) { return; } - const auto& delta_array = *model->arrays[op->inputs[2]]; + const auto& delta_array = model->GetArray(op->inputs[2]); if (!delta_array.has_shape()) { return; } @@ -599,7 +621,7 @@ void ProcessRangeOperator(Model* model, RangeOperator* op) { void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { CHECK_EQ(op->inputs.size(), 2); const string& input_name = op->inputs[1]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -618,13 +640,13 @@ void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { CHECK_EQ(op->outputs.size(), op->num_split); for (const auto& output : op->outputs) { - model->arrays[output]->copy_shape(output_shape); + model->GetArray(output).copy_shape(output_shape); } } void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -641,7 +663,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -658,7 +680,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -679,14 +701,14 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); - if (!model->arrays[op->inputs[0]]->has_shape() || - !model->arrays[op->inputs[1]]->has_shape()) { + if (!model->GetArray(op->inputs[0]).has_shape() || + !model->GetArray(op->inputs[1]).has_shape()) { return; } - const auto& input_data_shape = model->arrays[op->inputs[0]]->shape(); + const auto& input_data_shape = model->GetArray(op->inputs[0]).shape(); const string& output_size_name = op->inputs[1]; - const auto& output_size_array = *model->arrays[output_size_name]; + const auto& output_size_array = model->GetArray(output_size_name); CHECK(output_size_array.data_type == ArrayDataType::kInt32); CHECK(output_size_array.has_shape()); const auto& output_size_shape = output_size_array.shape(); @@ -697,9 +719,9 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { } std::vector output_shape = output_size_array.GetBuffer().data; - model->arrays[op->outputs[0]]->copy_shape( - Shape({input_data_shape.dims(0), output_shape[0], output_shape[1], - input_data_shape.dims(3)})); + model->GetArray(op->outputs[0]) + .copy_shape(Shape({input_data_shape.dims(0), output_shape[0], + output_shape[1], input_data_shape.dims(3)})); } void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { @@ -708,7 +730,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); const auto& input_array = - *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]); // Yield until all input dims have been resolved. if (!input_array.has_shape()) { return; @@ -717,7 +739,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_GE(input_shape.dimensions_count(), 2); const auto& prev_activ_array = - *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]); // Yield until all input dims have been resolved. if (!prev_activ_array.has_shape()) { return; @@ -726,7 +748,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_GE(prev_activ_shape.dimensions_count(), 2); const auto& weights_array = - *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -735,7 +757,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_EQ(weights_shape.dimensions_count(), 2); const auto& bias_array = - *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]); // Yield until bias dims have been resolved. if (!bias_array.has_shape()) { return; @@ -744,7 +766,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_GE(bias_shape.dimensions_count(), 1); const auto& prev_state_array = - *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]); // Yield until all input dims have been resolved. if (!prev_state_array.has_shape()) { return; @@ -784,7 +806,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { } void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -797,8 +819,8 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { const auto input_height = input_shape.dims(1); const auto input_width = input_shape.dims(2); - const auto& block_shape_array = *model->arrays[op->inputs[1]]; - const auto& paddings_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array = model->GetArray(op->inputs[1]); + const auto& paddings_array = model->GetArray(op->inputs[2]); const auto& block_shape_array_shape = block_shape_array.shape(); const auto& paddings_array_shape = paddings_array.shape(); QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); @@ -830,13 +852,13 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { int output_height = height_with_paddings / block_height; int output_width = width_with_paddings / block_width; - model->arrays[op->outputs[0]]->copy_shape( - Shape({input_shape.dims(0) * block_height * block_width, output_height, - output_width, input_shape.dims(3)})); + model->GetArray(op->outputs[0]) + .copy_shape(Shape({input_shape.dims(0) * block_height * block_width, + output_height, output_width, input_shape.dims(3)})); } void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -846,8 +868,8 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { const auto input_height = input_shape.dims(1); const auto input_width = input_shape.dims(2); - const auto& block_shape_array = *model->arrays[op->inputs[1]]; - const auto& crops_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array = model->GetArray(op->inputs[1]); + const auto& crops_array = model->GetArray(op->inputs[2]); const auto& block_shape_array_shape = block_shape_array.shape(); const auto& crops_array_shape = crops_array.shape(); QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); @@ -882,15 +904,15 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { int output_height = input_height * block_height; int output_width = input_width * block_width; - model->arrays[op->outputs[0]]->copy_shape( - Shape({input_shape.dims(0) / (block_height * block_width), output_height, - output_width, input_shape.dims(3)})); + model->GetArray(op->outputs[0]) + .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width), + output_height, output_width, input_shape.dims(3)})); } void ProcessGatherOperator(Model* model, GatherOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; - const auto& indices_array = *model->arrays[op->inputs[1]]; - auto& output_array = *model->arrays[op->outputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); + const auto& indices_array = model->GetArray(op->inputs[1]); + auto& output_array = model->GetArray(op->outputs[0]); // Bail if we already know the output shape. if (output_array.has_shape()) { @@ -924,7 +946,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) return; @@ -932,7 +954,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) { if (op->left_padding.empty()) return; CHECK_EQ(op->left_padding.size(), op->right_padding.size()); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) return; Shape output_shape = input_array.shape(); @@ -949,13 +971,13 @@ void ProcessPadOperator(Model* model, PadOperator* op) { void ProcessRankOperator(Model* model, RankOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -970,13 +992,13 @@ void ProcessRankOperator(Model* model, RankOperator* op) { void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -991,7 +1013,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { void ProcessStackOperator(Model* model, StackOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; @@ -1032,7 +1054,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) { void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; @@ -1112,12 +1134,12 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { CHECK_EQ(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) return; - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) return; const std::vector& input_dims = input_array.shape().dims(); @@ -1136,18 +1158,18 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { void ProcessSvdfOperator(Model* model, SvdfOperator* op) { CHECK(op->inputs.size() == 3 || op->inputs.size() == 4); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) return; - auto& weights_feature_array = *model->arrays[op->inputs[1]]; + auto& weights_feature_array = model->GetArray(op->inputs[1]); if (!weights_feature_array.has_shape()) return; - const auto& weights_time_array = *model->arrays[op->inputs[2]]; + const auto& weights_time_array = model->GetArray(op->inputs[2]); if (!weights_time_array.has_shape()) return; const bool has_bias = (op->inputs.size() == 4); if (has_bias) { - const auto& bias_array = *model->arrays[op->inputs[3]]; + const auto& bias_array = model->GetArray(op->inputs[3]); if (!bias_array.has_shape()) return; } @@ -1164,13 +1186,13 @@ void ProcessSvdfOperator(Model* model, SvdfOperator* op) { } void ProcessTransposeOperator(Model* model, TransposeOperator* op) { - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // We have already run return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -1204,7 +1226,7 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { CHECK_EQ(op->inputs.size(), 2); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -1222,7 +1244,7 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { } output_dims.push_back(1); const string& output_name = op->outputs[0]; - auto& output_array = *model->arrays[output_name]; + auto& output_array = model->GetArray(output_name); if (output_array.has_shape()) { return; } @@ -1236,8 +1258,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { auto* op = it->get(); std::unordered_map> old_output_dims; for (const auto& output : op->outputs) { - if (model->arrays[output]->has_shape()) { - old_output_dims[output] = model->arrays[output]->shape().dims(); + if (model->GetArray(output).has_shape()) { + old_output_dims[output] = model->GetArray(output).shape().dims(); } } @@ -1282,6 +1304,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowGreaterEqual: ProcessSimpleBinaryOperator(model, op); break; + case OperatorType::kAddN: + ProcessAddNOperator(model, op); + break; case OperatorType::kConv: ProcessConvOperator(model, static_cast(op)); break; @@ -1433,10 +1458,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { // Return true if any output dim changed, false if none changed. // Assumption: no transformation clears an output shape, they only add shapes. for (const auto& output : op->outputs) { - if (model->arrays[output]->has_shape() && - (old_output_dims[output] != model->arrays[output]->shape().dims())) { + if (model->GetArray(output).has_shape() && + (old_output_dims[output] != model->GetArray(output).shape().dims())) { AddMessageF("Set shape of %s to [%s]", output, - absl::StrJoin(model->arrays[output]->shape().dims(), ",")); + absl::StrJoin(model->GetArray(output).shape().dims(), ",")); return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 56082b965a7cbd9d61cca2e26f7d76764c0e54aa..b973b2b813147cc580d2e87cea7d395f180f5aa1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -412,7 +412,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) { model->flags.set_output_arrays(i, dequantize_op->inputs[0]); } } - model->arrays.erase(dequantize_op->outputs[0]); + model->EraseArray(dequantize_op->outputs[0]); model->operators.erase(dequantize_it); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc index 371ced388a8111c18ada32cf31a784809479291d..11f8d4b6eea836c5fe4efcbd5136e6183a59dc62 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc @@ -80,7 +80,7 @@ bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { // else. for (int i = 1; i <= 2; i++) { if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { - model->arrays.erase(fakequant_op->inputs[i]); + model->EraseArray(fakequant_op->inputs[i]); } } fakequant_op->inputs.resize(1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc index 3992e7d1ef71edd4040e626d5848d2fd9bb3dab6..c3b2709a33d54213661ba96394b01aa2cfd1a278 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -51,7 +51,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { // Remove the node and its output array. AddMessageF("Removed final %s", LogName(*dequantize_op)); - model->arrays.erase(output); + model->EraseArray(output); model->operators.erase(dequantize_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc index 6add443f2d62fd06e8c0d17e03bc78c5d74732a1..95a50c61794092b02e518d1f08d8cf4a668353a8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -81,7 +81,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { // Now check if the constant operand makes this binary // operator trivial. const auto& constant_input_array = - *model->arrays[binary_op->inputs[index_of_constant_input]]; + model->GetArray(binary_op->inputs[index_of_constant_input]); // For now, we only handle floats here. if (constant_input_array.data_type != ArrayDataType::kFloat) { return false; @@ -89,14 +89,14 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { const auto& constant_input_float_data = constant_input_array.GetBuffer().data; bool is_trivial = false; - if (binary_op->type != OperatorType::kAdd) { + if (binary_op->type == OperatorType::kAdd) { is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); - } else if (binary_op->type != OperatorType::kSub) { + } else if (binary_op->type == OperatorType::kSub) { is_trivial = index_of_constant_input == 1 && AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); - } else if (binary_op->type != OperatorType::kMul) { + } else if (binary_op->type == OperatorType::kMul) { is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); - } else if (binary_op->type != OperatorType::kDiv) { + } else if (binary_op->type == OperatorType::kDiv) { is_trivial = index_of_constant_input == 1 && AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc index 23a5c857e8b19f7edbb48f2c004d03e21008833d..936854a04fd600ea23ab5dda50370f85a311c28c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc @@ -59,7 +59,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { for (const string& input : trivial_inputs) { if (IsDiscardableArray(*model, input) && CountOpsWithInput(*model, input) == 1) { - model->arrays.erase(input); + model->EraseArray(input); } } concat_op->inputs = nontrivial_inputs; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index 047389f69a1d8987b52b07478b0d3eaf46f433ba..587f171bbf823408a45083c36d52f1d38c300123 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -124,7 +124,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, } } if (!is_referenced) { - model->arrays.erase(removal_candidate); + model->EraseArray(removal_candidate); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h index a06181ca0b5f1cbb930fa4295fec3d6adf66440d..9d448c3ee9088c16b96aa7ddc84457d2cab3231a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" @@ -54,4 +54,4 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc index e6cca8acf36745d989fb731aa948f257375d7e90..aa2c293382a98b476bee783ed8e177b19d35b858 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -33,7 +33,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { // the model. We allow specifying an arbitrary input_array, // treating the part of the graph leading up to it as unused. for (const auto& output : op->outputs) { - CHECK(model->arrays.count(output)); + CHECK(model->HasArray(output)); // If this output is provided as the model's input array, // then we don't need this operator to produce its contents. if (IsInputArray(*model, output)) { @@ -93,7 +93,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { if (IsDiscardableArray(*model, input) && CountOpsWithInput(*model, input) == 1 && !GetOpWithOutput(*model, input)) { - model->arrays.erase(input); + model->EraseArray(input); } } @@ -116,7 +116,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { continue; } // Generic case: do delete this output array. - model->arrays.erase(output); + model->EraseArray(output); } model->operators.erase(it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 3eb7fa3896c57ea612f21f8b4f3fa568d19420d4..fb109eb91b16e3a73005230f821c18b9ef82d2fb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -121,9 +121,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { } // Remove the old param arrays - model->arrays.erase(bn_op->inputs[1]); - model->arrays.erase(bn_op->inputs[2]); - model->arrays.erase(bn_op->inputs[3]); + model->EraseArray(bn_op->inputs[1]); + model->EraseArray(bn_op->inputs[2]); + model->EraseArray(bn_op->inputs[3]); // Remove the old operator DCHECK_EQ(bn_it->get(), bn_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index 7777d4f54359071c775806999ecf1418a8762d60..a06919e228dc2084f8943a714a0ca111d013c159 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -42,7 +42,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { return false; // Handle crops - const auto& crops_array = *model->arrays[op->inputs[2]]; + const auto& crops_array = model->GetArray(op->inputs[2]); if (!crops_array.has_shape()) return false; const std::vector& crops_dims = crops_array.shape().dims(); if (crops_dims.size() != 2) { @@ -58,7 +58,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { } // Handle block_shape - const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& block_shape_array = model->GetArray(op->inputs[1]); if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index fd51df4058dbda4732686983f9b9dab3781ec4d1..5e779f6765262326bd59db886c2feed603e0102e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -166,8 +166,9 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, void EvaluateBinaryOperatorOnConstantInputs(Model* model, const Operator* binary_op) { - const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type; - const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type; + const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type; + const auto output_data_type = + model->GetArray(binary_op->outputs[0]).data_type; #define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \ if (inputs_data_type == InputsDataType && \ output_data_type == OutputDataType) { \ @@ -214,7 +215,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { return false; } - auto& output_array = *model->arrays[binary_op->outputs[0]]; + auto& output_array = model->GetArray(binary_op->outputs[0]); // Yield until the output array dims have been resolved. if (!output_array.has_shape()) { return false; @@ -239,10 +240,10 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { // Remove the binary operator and its inputs if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) { - model->arrays.erase(binary_op->inputs[0]); + model->EraseArray(binary_op->inputs[0]); } if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) { - model->arrays.erase(binary_op->inputs[1]); + model->EraseArray(binary_op->inputs[1]); } AddMessageF("Resolved constant %s to the equivalent constant array", LogName(*binary_op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index 9835f86398a37f118d3ebd5b568ffddbcd56c38b..5ac449749adbc9b5422f996eeccb72575dca8722 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -189,7 +189,10 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // Remove all the resolved arrays. for (const string& input_name : concat_op->inputs) { - model->arrays.erase(input_name); + // Check to prevent removal of shared tensors + if(CountOpsWithInput(*model, input_name) == 1) { + model->EraseArray(input_name); + } } // Remove concatenate operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index 244adcc4c46eda9de79dd753565113bbeca970c5..81fe37d7e017c6e2440de34cc2daedf7fb2a422e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -66,7 +66,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { output_buffer.data[i] = dst_val; } if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { - model->arrays.erase(fakequant_op->inputs[0]); + model->EraseArray(fakequant_op->inputs[0]); } model->operators.erase(fakequant_it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc index 9da51d9147a98a935d00db04827aa7ebb12998b9..f6f95481b57f58f497b119df73d331f13d9705c0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc @@ -104,11 +104,11 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { // Erase input arrays if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } if (IsDiscardableArray(*model, op->inputs[1]) && CountOpsWithInput(*model, op->inputs[1]) == 1) { - model->arrays.erase(op->inputs[1]); + model->EraseArray(op->inputs[1]); } // Erase the operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc index 383d54aa5a7fa4933a9eb9ffac014bab4497d40d..1a0ba9e2bc7235720b59210cdd6affa089613077 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc @@ -28,17 +28,17 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { auto* op = static_cast(base_op); CHECK_EQ(op->inputs.size(), 3); - const auto& start_array = *model->arrays[op->inputs[0]]; + const auto& start_array = model->GetArray(op->inputs[0]); if (!start_array.has_shape()) { // Yield until all input dims have been resolved. return false; } - const auto& limit_array = *model->arrays[op->inputs[1]]; + const auto& limit_array = model->GetArray(op->inputs[1]); if (!limit_array.has_shape()) { // Yield until all input dims have been resolved. return false; } - const auto& delta_array = *model->arrays[op->inputs[2]]; + const auto& delta_array = model->GetArray(op->inputs[2]); if (!delta_array.has_shape()) { // Yield until all input dims have been resolved. return false; @@ -52,7 +52,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { } CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes return false; @@ -87,15 +87,15 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { // Delete the input array if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } if (IsDiscardableArray(*model, op->inputs[1]) && CountOpsWithInput(*model, op->inputs[1]) == 1) { - model->arrays.erase(op->inputs[1]); + model->EraseArray(op->inputs[1]); } if (IsDiscardableArray(*model, op->inputs[2]) && CountOpsWithInput(*model, op->inputs[2]) == 1) { - model->arrays.erase(op->inputs[2]); + model->EraseArray(op->inputs[2]); } // Delete the operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 35b81dd5506cfb0048ab1347bfefd07b128bc92b..9ea01acd05364224ce219bed533c999793a2a2f1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -62,7 +62,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { // Delete the input array if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } model->operators.erase(it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc index 86c76141a4705de841c8e70790cce7be28fb59c9..ea0d6dc8200897db9266efbe41556dbf4c296db3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc @@ -101,7 +101,7 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) { for (const auto& input : op->inputs) { if (IsDiscardableArray(*model, input) && CountOpsWithInput(*model, input) == 1) { - model->arrays.erase(input); + model->EraseArray(input); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 3976d9cbb492138c0c45801045833e08411acbd4..a0cfc3d59763dc1211ed4d1ac114d371a4a7ee0b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -186,7 +186,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { // Erase input array if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } // Erase the operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index 26ff9d887b40651559ad030cd41a824679d6dd15..1cd2aff28c68eaba4e9b18d8e2c2803834328696 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -199,7 +199,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } for (const auto& input : unary_op->inputs) { if (CountOpsWithInput(*model, input) == 1) { - model->arrays.erase(input); + model->EraseArray(input); } } AddMessageF("Resolved constant %s to the equivalent constant array", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc index b77be3f5c0d04b028391c1ce9de39afd7632eb36..013b50ac9ba8a51c23b19953d987b2fbf63fcea1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc @@ -36,7 +36,7 @@ bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { if (op->inputs.size() != 2) return false; if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - const auto& indices_array = *model->arrays[op->inputs[1]]; + const auto& indices_array = model->GetArray(op->inputs[1]); if (!indices_array.has_shape()) return false; op->axis = indices_array.GetBuffer().data; return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc index d5f5869c625f419a825f6bd652a04eca1bce4a6f..8a8e723cf7b2d77ec199e3817464a068bf85afdd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -35,7 +35,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 2); if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - const auto& array = *model->arrays[op->inputs[1]]; + const auto& array = model->GetArray(op->inputs[1]); if (!array.has_shape()) return false; const std::vector& dims = array.shape().dims(); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index b5093bc4c7c33b3e555ca14151c2489cddc6dbd3..5c68f87f6ccd912a94213c95a59a78076b0e768b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -103,7 +103,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { AddMessageF("Reordered axes for array %s", input_array_name); // Remove the op and output array. - model->arrays.erase(output_array_name); + model->EraseArray(output_array_name); model->operators.erase(reorder_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc index bed2a85bd262c49913f22e522d260c4dc6510246..2e063e35548aa5e51c3bcc94a2dfc7992180d014 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -37,7 +37,7 @@ bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { if (!op->shape.empty()) return false; if (IsConstantParameterArray(*model, reshape_op->inputs[1])) { - const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]]; + const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]); op->shape = constant_input_array.GetBuffer().data; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc index 1d0a2ec8f6c1f532f23873062534a37e07fff72b..e760d08e5a6c2f56db6b11fee922b701d33dd1a0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -36,10 +36,10 @@ bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) { if (!IsConstantParameterArray(*model, op->inputs[1])) return false; if (!IsConstantParameterArray(*model, op->inputs[2])) return false; - const auto& begin_array = *model->arrays[op->inputs[1]]; + const auto& begin_array = model->GetArray(op->inputs[1]); if (!begin_array.has_shape()) return false; - const auto& size_array = *model->arrays[op->inputs[2]]; + const auto& size_array = model->GetArray(op->inputs[2]); if (!size_array.has_shape()) return false; op->begin = begin_array.GetBuffer().data; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index a73f16735cb232753e8f64caae31f5c945b6bffd..dad6aceccfd201b3db07c29c99a8c6ef75bb89a1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -45,7 +45,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { return false; // Handle paddings. - const auto& paddings_array = *model->arrays[op->inputs[paddings_index]]; + const auto& paddings_array = model->GetArray(op->inputs[paddings_index]); if (!paddings_array.has_shape()) return false; const std::vector& paddings_dims = paddings_array.shape().dims(); if (paddings_dims.size() != 2) { @@ -61,7 +61,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { } // Handle block_shape. - const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]]; + const auto& block_shape_array = + model->GetArray(op->inputs[block_shape_index]); if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index dbe69adcbd34bb0544239ebb096fb8bfc4bfcb49..7e8b249b07ecca551cbb75afd8007efad0b52eaf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -31,13 +31,17 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { } CHECK_EQ(op->inputs.size(), 4); - const auto& start_array = *model->arrays[op->inputs[1]]; + const auto& start_array = model->GetArray(op->inputs[1]); if (!start_array.has_shape()) return false; + if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) { + // Only 1-4D arrays are supported for now. + return false; + } - const auto& stop_array = *model->arrays[op->inputs[2]]; + const auto& stop_array = model->GetArray(op->inputs[2]); if (!stop_array.has_shape()) return false; - const auto& stride_array = *model->arrays[op->inputs[3]]; + const auto& stride_array = model->GetArray(op->inputs[3]); if (!stride_array.has_shape()) return false; if (!IsConstantParameterArray(*model, op->inputs[1])) return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index c6723a880ed0e51cc5828f77742a6c8eb70fa864..5c0c1e3478fa0d94104d1b76bab176b98b314c50 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -75,7 +75,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { // Remove the axis array if it is not used by anything else. if (CountOpsWithInput(*model, axis_name) == 1) { - model->arrays.erase(axis_name); + model->EraseArray(axis_name); } // Remove the TensorFlowConcat op model->operators.erase(concat_it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index bea7487051a58344a56a3186a05d0fdceebc8727..ad1e56888e53133c5a84cc0e3d5e76b7ef3b29b4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -69,7 +69,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { LogName(*matmul_op), LogName(*fc_op)); const auto& previous_op_output = previous_op->outputs[0]; if (CountOpsWithInput(*model, previous_op_output) == 1) { - model->arrays.erase(previous_op_output); + model->EraseArray(previous_op_output); } CHECK_EQ(previous_op->inputs.size(), 2); fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]}; @@ -78,7 +78,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { const auto& previous_op_shape = previous_op->inputs[1]; if (CountOpsWithInput(*model, previous_op_shape) == 1 && !GetOpWithOutput(*model, previous_op_shape)) { - model->arrays.erase(previous_op_shape); + model->EraseArray(previous_op_shape); } model->operators.erase(previous_op_it); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index cfa5ce0716523adbfb0a76e89ce3b202f0595763..477e7f13da3d88a68547d494011cd4984936b909 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -55,7 +55,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { // Remove the node and its output array. AddMessageF("Removing already-resolved %s", LogName(*merge_op)); - model->arrays.erase(merge_op->outputs[0]); + model->EraseArray(merge_op->outputs[0]); model->operators.erase(merge_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index 150cf53da3099227c5c637ee58c44512d5a41d4f..a418073441f1241a5acb1164b36f332828ea2e99 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -103,7 +103,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { // Remove the output arrays if they are now unused. for (int i = 0; i < 2; i++) { if (!GetOpWithInput(*model, switch_op->outputs[i])) { - model->arrays.erase(switch_op->outputs[i]); + model->EraseArray(switch_op->outputs[i]); } } // Remove input arrays if they are only used by the switch itself and aren't @@ -111,7 +111,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { for (const auto& input : switch_op->inputs) { if (CountOpsWithInput(*model, input) == 1 && !GetOpWithOutput(*model, input)) { - model->arrays.erase(input); + model->EraseArray(input); } } // Remove the switch node itself. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc index 9f7e7c42a26b60c96573be6653babb78fdb5fd73..1ddf54c778cd1fae7a8fce0ecb97209274e71ac0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc @@ -45,10 +45,10 @@ void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, model->operators.erase(tile_it); if (!CountOpsWithInput(*model, tile_multiplier_array) && !GetOpWithOutput(*model, tile_multiplier_array)) { - model->arrays.erase(tile_multiplier_array); + model->EraseArray(tile_multiplier_array); } if (!CountOpsWithInput(*model, tile_output_array)) { - model->arrays.erase(tile_output_array); + model->EraseArray(tile_output_array); } } } // namespace diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc index 12d966b26104fd491f914fbdb39e0a62fdda19bc..a657ee00af66bd431f96c361e12d5213e203b3df 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc @@ -35,7 +35,7 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) { if (!IsConstantParameterArray(*model, op->inputs[1])) return false; // Handling perm. - const auto& perm_array = *model->arrays[op->inputs[1]]; + const auto& perm_array = model->GetArray(op->inputs[1]); if (!perm_array.has_shape()) return false; const std::vector& perm_dims = perm_array.shape().dims(); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index a14016e8e2705a66c392118899335eb3997fa1de..3a1d175b9823f085c9b8730caba8bedd7eb87d52 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -//#include "tensorflow/contrib/lite/kernels/test_util.h" #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" @@ -168,11 +167,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); - EXPECT_THAT(model.arrays.size(), 5); + EXPECT_THAT(model.GetArrayMap().size(), 5); (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); - EXPECT_THAT(model.arrays.size(), 1); + EXPECT_THAT(model.GetArrayMap().size(), 1); - auto& concatenated_array = (*model.arrays.begin()).second; + auto& concatenated_array = (*model.GetArrayMap().begin()).second; EXPECT_THAT(concatenated_array->GetBuffer().data, ElementsAreArray(ArrayFloatNear( {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., @@ -187,11 +186,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); - EXPECT_THAT(model.arrays.size(), 5); + EXPECT_THAT(model.GetArrayMap().size(), 5); (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); - EXPECT_THAT(model.arrays.size(), 1); + EXPECT_THAT(model.GetArrayMap().size(), 1); - auto& concatenated_array = (*model.arrays.begin()).second; + auto& concatenated_array = (*model.GetArrayMap().begin()).second; EXPECT_THAT(concatenated_array->GetBuffer().data, ElementsAreArray(ArrayFloatNear( {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22., @@ -206,11 +205,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); - EXPECT_THAT(model.arrays.size(), 5); + EXPECT_THAT(model.GetArrayMap().size(), 5); (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); - EXPECT_THAT(model.arrays.size(), 1); + EXPECT_THAT(model.GetArrayMap().size(), 1); - auto& concatenated_array = (*model.arrays.begin()).second; + auto& concatenated_array = (*model.GetArrayMap().begin()).second; EXPECT_THAT(concatenated_array->GetBuffer().data, ElementsAreArray(ArrayFloatNear( {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12., diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc index 4e273343df9f3e5ade8f23a2fbd868bcab72c62e..2c7046c8c77c94a89fc05a26d7d72b3661380475 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -63,7 +63,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { ac_op->outputs = op->outputs; const string& tmp_array_name = AvailableArrayName(*model, op->outputs[0] + "_unfused"); - CHECK(!model->arrays.count(tmp_array_name)); + CHECK(!model->HasArray(tmp_array_name)); model->GetOrCreateArray(tmp_array_name); ac_op->inputs = {tmp_array_name}; op->outputs = {tmp_array_name}; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 995e9d67ca3ae34471595d2d629d2fe993c21ab5..ca378af4c5c1e1b8cf42a10d3820db3feeb49a05 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -696,6 +696,19 @@ void ConvertAddOperator(const NodeDef& node, model->operators.emplace_back(op); } +void ConvertAddNOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "AddN"); + const int num_inputs = GetInputsCount(node, tf_import_flags); + auto* op = new AddNOperator; + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + void ConvertMulOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1179,6 +1192,8 @@ void ConvertStridedSliceOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "StridedSlice"); + // TODO(soroosh): The 4th input (strides) should be e optional, to be + // consistent with TF. CheckInputsCount(node, tf_import_flags, 4); auto* op = new StridedSliceOperator; @@ -1652,7 +1667,7 @@ void StripCaretFromArrayNames(Model* model) { output = string(absl::StripPrefix(output, "^")); } } - for (auto& array : model->arrays) { + for (auto& array : model->GetArrayMap()) { if (absl::StartsWith(array.first, "^")) { LOG(FATAL) << "What?"; } @@ -1860,6 +1875,8 @@ std::unique_ptr ImportTensorFlowGraphDef( ConvertSquareOperator(node, tf_import_flags, model); } else if (node.op() == "Add") { ConvertAddOperator(node, tf_import_flags, model); + } else if (node.op() == "AddN") { + ConvertAddNOperator(node, tf_import_flags, model); } else if (node.op() == "Mul") { ConvertMulOperator(node, tf_import_flags, model); } else if (node.op() == "Sub") { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h index 312e3b8f17cfaa012bf25696937f97d396802bb2..2177872334bfec6147f865be1518e440c2c636ea 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.h +++ b/tensorflow/contrib/lite/toco/import_tensorflow.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ #include #include @@ -39,4 +39,4 @@ std::unique_ptr ImportTensorFlowGraphDef( } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 7b2235e2751e1bb359195a3d69f91725a5463434..d1af371fd4c43d7059bfd70597ea765c9c2e51fd 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ #include #include @@ -32,6 +32,7 @@ enum class OperatorType { kNone, // General-purpose neural network operators. kAdd, + kAddN, kAveragePool, kBatchNormalization, kConv, @@ -559,6 +560,16 @@ struct AddOperator : Operator { AddOperator() : Operator(OperatorType::kAdd) {} }; +// Element-wise addition operator for N inputs. +// +// Inputs: +// inputs[i]: The i-th array to add together to form the output. +// +// TensorFlow equivalent: AddN +struct AddNOperator : Operator { + AddNOperator() : Operator(OperatorType::kAddN) {} +}; + // Concatenation operator: concatenates its inputs // along the axis. // @@ -738,6 +749,9 @@ struct PadOperator : Operator { // // Inputs: // inputs[0]: required: the input array +// inputs[1]: required: the begin array +// inputs[2]: required: the end array +// inputs[3]: optional: the strides array // // TensorFlow equivalent: StridedSlice struct StridedSliceOperator : Operator { @@ -1521,29 +1535,54 @@ struct Array { // Our Model struct, represents an entire model (our "top-level" struct). // Owns everything. -struct Model { +class Model { + public: + using ArrayMap = std::unordered_map>; + + bool HasArray(const string& name) const { return arrays.count(name) > 0; } Array& GetArray(const string& name) const { - DCHECK(arrays.count(name)); + DCHECK(HasArray(name)); return *arrays.at(name); } Array& GetOrCreateArray(const string& name) { - if (!arrays.count(name)) { + // Make sure name is not used by an optional array + DCHECK(!optional_arrays.count(name)); + if (!HasArray(name)) { Array* ptr = new Array; arrays[name] = std::unique_ptr(ptr); } Array& result = GetArray(name); return result; } + void CreateOptionalArray(const string& name) { + DCHECK(!arrays.count(name) && !optional_arrays.count(name)); + optional_arrays.insert(name); + } + bool IsOptionalArray(const string& name) const { + return optional_arrays.count(name); + } + + // Note that this invalidates all array iterators. + void EraseArray(const string& name) { arrays.erase(name); } + void EraseArrays(std::function discardable) { + for (auto it = arrays.begin(); it != arrays.end();) { + if (discardable(it->first)) { + it = arrays.erase(it); + } else { + ++it; + } + } + } + const ArrayMap& GetArrayMap() const { return arrays; } + + // Optional arrays are used for optional tensors, + // these tensors do not have data, but with reserved names as op inputs. + std::set optional_arrays; // The list of operators. Notice how it's a list of unique_ptr's, implying // that the Model is what owns Operator's and keeps them alive. std::vector> operators; - // The associative array mapping names to Array's. - // Notice how it's a container of unique_ptr's, implying - // that the Model is what owns Array's and keeps them alive. - // The Operator's refer to these Array's by their name strings, not by their - // addresses. See Operator::inputs, Operator::outputs. - std::unordered_map> arrays; + // Generic flags, a place where we combine information passed to us via // command-line parameters (e.g. --input_width=N) with information that // we may or may not find in the input model file. @@ -1552,7 +1591,15 @@ struct Model { std::size_t transient_data_size = 0; // For code-generation only: required alignment of the transient_data buffer std::size_t transient_data_alignment = 0; + + private: + // The associative array mapping names to Array's. + // Notice how it's a container of unique_ptr's, implying + // that the Model is what owns Array's and keeps them alive. + // The Operator's refer to these Array's by their name strings, not by their + // addresses. See Operator::inputs, Operator::outputs. + std::unordered_map> arrays; }; } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.h b/tensorflow/contrib/lite/toco/model_cmdline_flags.h index 027d7ae1aa62b5b31b8fcebdc29d4f547507b7fe..c868d5c7d0b5a6ee81d99423414c87e4e6e7cf66 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.h +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ #include #include @@ -40,5 +40,4 @@ ParsedModelFlags* GlobalParsedModelFlags(); } // namespace toco - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/runtime/common.h b/tensorflow/contrib/lite/toco/runtime/common.h index bd55544f57f9a266514e878edd8f1f7dec1cb7b7..3c6828840c4a963a4a68774ec5d559b7f80baf22 100644 --- a/tensorflow/contrib/lite/toco/runtime/common.h +++ b/tensorflow/contrib/lite/toco/runtime/common.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK @@ -23,4 +23,4 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/common.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h index df63b2d59ea2a98f1ec9009614c18791e8822c14..f5de5a5781a5304634642680e6a3cef60e7b844b 100644 --- a/tensorflow/contrib/lite/toco/runtime/types.h +++ b/tensorflow/contrib/lite/toco/runtime/types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" @@ -29,4 +29,4 @@ using tflite::RequiredBufferSizeForDims; } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.h b/tensorflow/contrib/lite/toco/tensorflow_util.h index 152b4f7a727a88f721f1a63299ea4fa709bb5d52..61f91042685288a48ba19f8c67d4c7c1960a7787 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_util.h +++ b/tensorflow/contrib/lite/toco/tensorflow_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ #include #include @@ -29,4 +29,4 @@ void LogDumpGraphDef(int log_level, const string& message, } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index 332253a092aff812fb18601862c66bc0423599c2..72c926656449da981abf6c11c03cd7c00a634ce7 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -27,7 +27,7 @@ cc_library( "//tensorflow/contrib/lite/toco:model", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", - "@flatbuffers//:flatbuffers", + "@flatbuffers", ], ) @@ -41,7 +41,7 @@ tf_cc_test( "//tensorflow/contrib/lite/toco:tooling_util", "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", - "@flatbuffers//:flatbuffers", + "@flatbuffers", ], ) @@ -87,7 +87,7 @@ cc_library( "//tensorflow/contrib/lite/toco:model", "//tensorflow/contrib/lite/toco:tooling_util", "@com_google_absl//absl/strings", - "@flatbuffers//:flatbuffers", + "@flatbuffers", ], ) @@ -117,7 +117,7 @@ cc_library( ":types", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/toco:model", - "@flatbuffers//:flatbuffers", + "@flatbuffers", ], ) @@ -131,7 +131,7 @@ tf_cc_test( "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", - "@flatbuffers//:flatbuffers", + "@flatbuffers", ], ) diff --git a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h index 93cc79ddb64fbc46a97a47ecdc155a8aabf5c3ef..cfe7ecd9f982618dea3b3a5d02e69e3f15434bc2 100644 --- a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h +++ b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ #include "absl/memory/memory.h" #include "tensorflow/contrib/lite/toco/tflite/operator.h" @@ -71,4 +71,4 @@ class BuiltinOperator : public BaseOperator { } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/custom_operator.h b/tensorflow/contrib/lite/toco/tflite/custom_operator.h index 1a4bfac7d4f684043d2a9ce8fc2c78dd738f4b69..bd5713618ff379be42fd1b76649cfb2cf55b843d 100644 --- a/tensorflow/contrib/lite/toco/tflite/custom_operator.h +++ b/tensorflow/contrib/lite/toco/tflite/custom_operator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ #include "flatbuffers/flexbuffers.h" #include "absl/memory/memory.h" @@ -71,4 +71,4 @@ class CustomOperator : public BaseOperator { } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index bec694a23377c7c70684000069e9c08ee446b6c0..391ef87029d019ab52af2716f72883f5f82f94d9 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -62,7 +62,7 @@ namespace details { void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { // First find a list of unique array names. std::set names; - for (const auto& array_pair : model.arrays) { + for (const auto& array_pair : model.GetArrayMap()) { names.insert(array_pair.first); } @@ -96,7 +96,7 @@ Offset>> ExportTensors( // tensors in the tensors_map. std::map> ordered_tensors; - for (const auto& array_pair : model.arrays) { + for (const auto& array_pair : model.GetArrayMap()) { const string& tensor_name = array_pair.first; const toco::Array& array = *array_pair.second; @@ -235,9 +235,10 @@ Offset>> ExportOperators( for (const auto& op : model.operators) { std::vector inputs; for (const string& input : op->inputs) { - inputs.push_back(tensors_map.at(input)); + // -1 is the ID for optional tensor in TFLite output + int id = model.IsOptionalArray(input) ? -1 : tensors_map.at(input); + inputs.push_back(id); } - std::vector outputs; for (const string& output : op->outputs) { outputs.push_back(tensors_map.at(output)); diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 44012b7126e17d730ea248551dea2414ad0072d9..8c79cb820015e16847ce48c171e8f6e41f60c319 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ #include "tensorflow/contrib/lite/toco/model.h" @@ -73,4 +73,4 @@ void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map); } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/import.h b/tensorflow/contrib/lite/toco/tflite/import.h index 3c27a2843c47814ad46c8f1bbd77b7afcb324375..280677bae189fa345c2e19f6399a7b9ac7629ab5 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.h +++ b/tensorflow/contrib/lite/toco/tflite/import.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/toco/model.h" @@ -46,4 +46,4 @@ void LoadOperatorsTable(const ::tflite::Model &input_model, } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc index 309fa6d7f688ba1dd99a7e6eeda14d513a9e49d4..aad6e780d5eb5c3dbc880906df5053ad231ffd54 100644 --- a/tensorflow/contrib/lite/toco/tflite/import_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc @@ -114,7 +114,7 @@ TEST_F(ImportTest, Tensors) { auto model = Import(ModelFlags(), InputModelAsString()); - ASSERT_GT(model->arrays.count("tensor_one"), 0); + ASSERT_GT(model->HasArray("tensor_one"), 0); Array& a1 = model->GetArray("tensor_one"); EXPECT_EQ(ArrayDataType::kFloat, a1.data_type); EXPECT_THAT(a1.GetBuffer().data, diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 0111e1ed92f479cd35f03971ff74ab08c4ccf55a..298f49025f9dc8b636dc76a04b8e2e5f11d27db7 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -474,19 +474,11 @@ class Pad : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto before_padding = builder->CreateVector(op.left_padding); - auto after_padding = builder->CreateVector(op.right_padding); - return ::tflite::CreatePadOptions(*builder, before_padding, after_padding); + return ::tflite::CreatePadOptions(*builder); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->left_padding.insert(op->left_padding.end(), - options.before_padding()->begin(), - options.before_padding()->end()); - op->right_padding.insert(op->right_padding.end(), - options.after_padding()->begin(), - options.after_padding()->end()); } }; @@ -617,6 +609,30 @@ class Split : public CustomOperator { } }; +class StridedSlice + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateStridedSliceOptions( + *builder, op.begin_mask, op.end_mask, op.ellipsis_mask, + op.new_axis_mask, op.shrink_axis_mask); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->begin_mask = options.begin_mask(); + op->end_mask = options.end_mask(); + op->ellipsis_mask = options.ellipsis_mask(); + op->new_axis_mask = options.new_axis_mask(); + op->shrink_axis_mask = options.shrink_axis_mask(); + } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -777,6 +793,8 @@ std::vector> BuildOperatorList() { new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); + ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, + OperatorType::kStridedSlice)); // Custom Operators. ops.emplace_back(new Cast("CAST", OperatorType::kCast)); @@ -789,6 +807,8 @@ std::vector> BuildOperatorList() { // There operators are supported by Toco, but not by TF Lite, and has no // attributes. + ops.emplace_back( + new SimpleOperator("ADDN", OperatorType::kAddN)); ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back(new SimpleOperator( "RSQRT", OperatorType::kTensorFlowRsqrt)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 37df302d4697c78e0349bcd30e0e1adc540066bc..88af3d6ab6c6af150af83ed5c52931f9f089aa3c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ #include "flatbuffers/flatbuffers.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" @@ -86,4 +86,4 @@ class BaseOperator { } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 77c70847d1e94fc5c7eeac6480a5286ba6557fab..9036a16d1c928702a71ccbe3fdad826fb037fcaf 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -258,16 +258,6 @@ TEST_F(OperatorTest, BuiltinMaxPool) { EXPECT_EQ(op.kheight, output_toco_op->kheight); } -TEST_F(OperatorTest, BuiltinPad) { - PadOperator op; - op.left_padding = {1, 2, 3}; - op.right_padding = {1, 2, 3}; - auto output_toco_op = - SerializeAndDeserialize(GetOperator("PAD", OperatorType::kPad), op); - EXPECT_EQ(op.left_padding, output_toco_op->left_padding); - EXPECT_EQ(op.right_padding, output_toco_op->right_padding); -} - TEST_F(OperatorTest, BuiltinReshape) { TensorFlowReshapeOperator op; op.shape = {1, 2, 4, 5, 8}; @@ -398,6 +388,28 @@ TEST_F(OperatorTest, Squeeze) { EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims); } +TEST_F(OperatorTest, StridedSlice) { + StridedSliceOperator op; + + op.begin_mask = 1; + op.end_mask = 2; + op.ellipsis_mask = 1; + op.new_axis_mask = 1; + op.shrink_axis_mask = 2; + + auto output_toco_op = SerializeAndDeserialize( + GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op); + EXPECT_EQ(op.start_indices, output_toco_op->start_indices); + EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices); + EXPECT_EQ(op.strides, output_toco_op->strides); + EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask); + EXPECT_EQ(op.end_mask, output_toco_op->end_mask); + EXPECT_EQ(op.end_mask, output_toco_op->end_mask); + EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask); + EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask); + EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h index 992b98bacafecb080e792ae87a2940977482eed6..72678c82a22a7168f858747b0b1c6a2b515b6578 100644 --- a/tensorflow/contrib/lite/toco/tflite/simple_operator.h +++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ #include "tensorflow/contrib/lite/toco/tflite/operator.h" @@ -47,4 +47,4 @@ class SimpleOperator : public BaseOperator { } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/types.h b/tensorflow/contrib/lite/toco/tflite/types.h index f7c51405107d954fa259809b72f56af193e344fb..3923756fc94e3175a6505740a96cce8d614c3990 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.h +++ b/tensorflow/contrib/lite/toco/tflite/types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/toco/model.h" @@ -55,4 +55,4 @@ struct ActivationFunction { } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h index ba35ca8d5d23f07d843ae6fa2099cc7e15b1e9a3..46eb3f57283cc52bf2877f578500f3a4a633be86 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ #include #include @@ -33,4 +33,4 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h index ae0541f62b61581e3ba183725a85fe51c54116dc..d6c3ba6543378b3e15b5fb7816f52376fe05123d 100644 --- a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h +++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ #include @@ -31,4 +31,4 @@ struct GraphVizDumpOptions { } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index b5cb7a11e7c46d02d398ff937d46e52368e88098..0572848cb5a998457cd669a2b0bce5fe8a0e15a2 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ // Portability layer for toco tool. Mainly, abstract filesystem access so we // can build and use on google internal environments and on OSX. @@ -77,4 +77,4 @@ void CopyToBuffer(const string& src, char* dest); } // namespace port } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 94b4d146968d4bf92bd8f662763eecdc92a66663..720c33777d707994c6e1003bb1210eadd96bc8a8 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -52,7 +52,9 @@ void MakeGeneralGraphTransformationsSet( GraphTransformationsSet* transformations) { CHECK(transformations->empty()); transformations->Add(new ConvertExpandDimsToReshape); + transformations->Add(new ConvertTrivialAddNToAdd); transformations->Add(new ConvertTrivialTransposeToReshape); + transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); transformations->Add(new PropagateArrayDataTypes); transformations->Add(new PropagateFixedSizes); @@ -96,7 +98,6 @@ void MakeGeneralGraphTransformationsSet( bool SupportsQuantization(FileFormat format) { return (format == GRAPHVIZ_DOT || format == TFLITE); - ; } bool SupportsFusedActivationFunction(FileFormat format) { @@ -133,7 +134,7 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) { for (int i = 0; i < model->flags.input_arrays_size(); i++) { string const& array_name = model->flags.input_arrays(i).name(); - auto* array = model->arrays[array_name].get(); + auto* array = &model->GetArray(array_name); // Note that the notion of changing data types only applies to real-numbers // arrays (see the documentation for inference_input_type). // TODO(benoitjacob) this is assuming that uint8 arrays are quantized, diff --git a/tensorflow/contrib/lite/toco/toco_tooling.h b/tensorflow/contrib/lite/toco/toco_tooling.h index 9c5a93a21170ba773b1160eb2e1261f85cdd70e5..e731c149eef412d3048a1d5f84145ce6ff87208d 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.h +++ b/tensorflow/contrib/lite/toco/toco_tooling.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ #include #include @@ -47,4 +47,4 @@ inline void Export(const TocoFlags& toco_flags, const Model& model, } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/contrib/lite/toco/toco_types.h index ad42497ada6cb0dbda673bf3aad406c9fedfb078..d72a3bd1f382679f81061a51f35586631b571400 100644 --- a/tensorflow/contrib/lite/toco/toco_types.h +++ b/tensorflow/contrib/lite/toco/toco_types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ #include #include "tensorflow/core/platform/platform.h" @@ -42,4 +42,4 @@ using tensorflow::uint8; } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index e09a469d55bae7d2abc6bfa5a3e78ce41ae7a4f5..6577bb778184ef774c4102aa0a22153a428d5c61 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -93,7 +93,7 @@ int CountOpsWithInput(const Model& model, const string& array_name) { bool DeleteArrayIfUnused(const string& array_name, Model* model) { if (CountOpsWithInput(*model, array_name) == 0) { - model->arrays.erase(array_name); + model->EraseArray(array_name); return true; } return false; @@ -197,6 +197,7 @@ const char* OperatorTypeName(OperatorType type) { case OperatorType::k##c: \ return #c; HANDLE_OPERATORTYPENAME_CASE(Add) + HANDLE_OPERATORTYPENAME_CASE(AddN) HANDLE_OPERATORTYPENAME_CASE(AveragePool) HANDLE_OPERATORTYPENAME_CASE(BatchNormalization) HANDLE_OPERATORTYPENAME_CASE(Conv) @@ -566,11 +567,11 @@ int RequiredBufferSizeForShape(const Shape& shape) { } bool IsConstantParameterArray(const Model& model, const string& name) { - if (!model.arrays.count(name)) { + if (!model.HasArray(name)) { return false; } - return !!model.arrays.at(name)->buffer; + return !!model.GetArray(name).buffer; } namespace { @@ -633,17 +634,17 @@ void CheckNonExistentIOArrays(const Model& model) { return; } for (const auto& input_array : model.flags.input_arrays()) { - CHECK(model.arrays.count(input_array.name())) + CHECK(model.HasArray(input_array.name())) << "Input array not found: " << input_array.name(); } for (const string& output_array : model.flags.output_arrays()) { - CHECK(model.arrays.count(output_array)) + CHECK(model.HasArray(output_array)) << "Output array not found: " << output_array; } for (const auto& rnn_state : model.flags.rnn_states()) { if (!rnn_state.discardable()) { - CHECK(model.arrays.count(rnn_state.state_array())); - CHECK(model.arrays.count(rnn_state.back_edge_source_array())); + CHECK(model.HasArray(rnn_state.state_array())); + CHECK(model.HasArray(rnn_state.back_edge_source_array())); } } } @@ -652,10 +653,13 @@ void CheckNonExistentIOArrays(const Model& model) { void CheckNoMissingArray(const Model& model) { for (const auto& op : model.operators) { for (const auto& input : op->inputs) { - CHECK(model.arrays.count(input)); + CHECK(model.HasArray(input) || model.optional_arrays.count(input)) + << "Input: " << input << " missing for op: " + << op->outputs[0] << "."; } for (const auto& output : op->outputs) { - CHECK(model.arrays.count(output)); + CHECK(model.HasArray(output)) << "Output: " << output + << " missing."; } } CheckNonExistentIOArrays(model); @@ -664,12 +668,12 @@ void CheckNoMissingArray(const Model& model) { void FixNoMissingArray(Model* model) { for (const auto& op : model->operators) { for (const auto& input : op->inputs) { - if (!model->arrays.count(input)) { + if (!model->HasArray(input)) { model->GetOrCreateArray(input); } } for (const auto& output : op->outputs) { - if (!model->arrays.count(output)) { + if (!model->HasArray(output)) { model->GetOrCreateArray(output); } } @@ -687,7 +691,7 @@ void FixNoMissingArray(Model* model) { void CheckNoOrphanedArray(const Model& model) { std::unordered_set arrays_without_known_use; - for (const auto& array : model.arrays) { + for (const auto& array : model.GetArrayMap()) { if (IsDiscardableArray(model, array.first)) { arrays_without_known_use.insert(array.first); } @@ -714,7 +718,7 @@ void CheckNoOrphanedArray(const Model& model) { void FixNoOrphanedArray(Model* model) { std::unordered_set arrays_without_known_use; - for (const auto& array : model->arrays) { + for (const auto& array : model->GetArrayMap()) { arrays_without_known_use.insert(array.first); } for (const auto& op : model->operators) { @@ -731,13 +735,13 @@ void FixNoOrphanedArray(Model* model) { } for (const auto& array : arrays_without_known_use) { if (IsDiscardableArray(*model, array)) { - model->arrays.erase(array); + model->EraseArray(array); } } } void CheckArrayFieldsConsistent(const Model& model) { - for (const auto& array_entry : model.arrays) { + for (const auto& array_entry : model.GetArrayMap()) { const auto& array = array_entry.second; if (array->has_shape()) { for (int d : array->shape().dims()) { @@ -756,11 +760,13 @@ void CheckArrayFieldsConsistent(const Model& model) { void CheckOperatorOrdering(const Model& model) { std::unordered_set arrays_behind_us; - for (const auto& array_entry : model.arrays) { + for (const auto& array_entry : model.GetArrayMap()) { if (!GetOpWithOutput(model, array_entry.first)) { arrays_behind_us.insert(array_entry.first); } } + arrays_behind_us.insert(model.optional_arrays.begin(), + model.optional_arrays.end()); for (const auto& op : model.operators) { for (const auto& input : op->inputs) { if (!IsConstantParameterArray(model, input)) { @@ -779,11 +785,13 @@ void CheckOperatorOrdering(const Model& model) { void FixOperatorOrdering(Model* model) { std::unordered_set arrays_behind_us; - for (const auto& array_entry : model->arrays) { + for (const auto& array_entry : model->GetArrayMap()) { if (!GetOpWithOutput(*model, array_entry.first)) { arrays_behind_us.insert(array_entry.first); } } + arrays_behind_us.insert(model->optional_arrays.begin(), + model->optional_arrays.end()); std::vector> old_operators; std::swap(old_operators, model->operators); std::set remaining; @@ -932,7 +940,8 @@ void CheckModelCounts(const Model& model) { if (count_type == "None") { continue; } else if (count_type == "Arrays") { - CheckCountInRange(model_check, model.arrays.size(), "count of arrays"); + CheckCountInRange(model_check, model.GetArrayMap().size(), + "count of arrays"); } else if (count_type == "Total") { CheckCountInRange(model_check, model.operators.size(), "count of all operator instances"); @@ -1281,6 +1290,8 @@ void DropMinMax(Model* model, const string& array_name) { } bool IsAllocatableTransientArray(const Model& model, const string& array_name) { + // Optional array is not transient + if (model.IsOptionalArray(array_name)) return false; // The model's input and output arrays are externally allocated. // They are not transient arrays. if (IsInputArray(model, array_name)) { @@ -1291,7 +1302,7 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { return false; } } - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); // An array with a constant buffer isn't a transient array. if (!!array->buffer) { return false; @@ -1304,13 +1315,13 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { } string AvailableArrayName(const Model& model, const string& name) { - if (!model.arrays.count(name)) { + if (!model.HasArray(name) && !model.optional_arrays.count(name)) { return name; } const int kNumSuffixesToTry = 1000; for (int i = 0; i < kNumSuffixesToTry; i++) { const string& name_with_suffix = toco::port::StringF("%s_%d", name, i); - if (!model.arrays.count(name_with_suffix)) { + if (!model.HasArray(name_with_suffix)) { return name_with_suffix; } } @@ -1328,12 +1339,12 @@ string ShapeToString(const Shape& shape) { } void PrintArrayShape(Model* model, const string& name) { - if (!model->arrays[name]->has_shape()) { + if (!model->GetArray(name).has_shape()) { LOG(INFO) << name << " has no shape"; return; } LOG(INFO) << name - << " has shape: " << ShapeToString(model->arrays[name]->shape()); + << " has shape: " << ShapeToString(model->GetArray(name).shape()); } bool IsArrayFullyConnectedWeights(const Model& model, const string& name) { @@ -1389,6 +1400,16 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) { total += RequiredBufferSizeForShape(output_array.shape()); break; } + case OperatorType::kAddN: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // AddN cost is roughly the same cost as N-1 Adds. + const int num_adds = op->inputs.size() - 1; + total += num_adds * RequiredBufferSizeForShape(output_array.shape()); + break; + } case OperatorType::kLogistic: case OperatorType::kSoftmax: case OperatorType::kTanh: { @@ -1456,8 +1477,6 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) { return true; } -namespace { - void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, std::vector* shuffle) { CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order)); @@ -1492,6 +1511,8 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, } } +namespace { + // Extend shuffle is designed to match ExtendShape, which pads the shape with // unit dimensions at the beginning. void ExtendShuffle(const std::vector& input_shuffle, int newdim, @@ -1667,7 +1688,7 @@ bool IsDiscardableArray(const Model& model, const string& array_name) { } void CheckFinalDataTypesSatisfied(const Model& model) { - for (const auto& array_entry : model.arrays) { + for (const auto& array_entry : model.GetArrayMap()) { const auto& array = *array_entry.second; if (array.final_data_type != ArrayDataType::kNone) { CHECK(array.final_data_type == array.data_type) diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index c81e77874e36d78ca3ee23f84f55596627e9c73d..5986d6364939e0f01b057ce3fb653b19fe8040cd 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ #include #include @@ -274,6 +274,11 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result); int AxesCount(AxesOrder axes_order); +// Returns the permutation of the dimensions based on the input axes order and +// output axes order. +void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, + std::vector* shuffle); + void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, Shape* output_shape); void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, @@ -295,4 +300,4 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type); } // namespace toco -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 389ef2323a376f33c0f539ef27a29c92b3d8be6e..20df905270b0692e2bc9b78fc020447108282d01 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -42,6 +42,8 @@ tf_cc_binary( }), deps = [ ":mutable_op_resolver", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", ], ) diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.h b/tensorflow/contrib/lite/tools/gen_op_registration.h index 318859e23d7b404c130f003b0e249893f2ed92fe..5f2ac6ca97fde9a2fe6f4bcf20184f6ef6606f0b 100644 --- a/tensorflow/contrib/lite/tools/gen_op_registration.h +++ b/tensorflow/contrib/lite/tools/gen_op_registration.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string.h" @@ -36,4 +36,4 @@ void ReadOpsFromModel(const ::tflite::Model* model, } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h index 906553da570720a0c4b90bbd2eebb6d8bdea6bb8..573a359c458acb6e4320c5a21cb378cdde720924 100644 --- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h +++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ #include #include "tensorflow/contrib/lite/context.h" @@ -52,4 +52,4 @@ class MutableOpResolver : public OpResolver { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/version.h b/tensorflow/contrib/lite/version.h index a751afabe7460f0c9e88385faf1497b2c0a25d6b..efd63f4006ae661c6fdbbaa81cb02fa8947271f3 100644 --- a/tensorflow/contrib/lite/version.h +++ b/tensorflow/contrib/lite/version.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_VERSION_H_ +#define TENSORFLOW_CONTRIB_LITE_VERSION_H_ // The version number of the Schema. Ideally all changes will be backward // compatible. If that ever changes, we must ensure that version is the first // entry in the new tflite root so that we can see that version is not 1. #define TFLITE_SCHEMA_VERSION (3) -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ +#endif // TENSORFLOW_CONTRIB_LITE_VERSION_H_ diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index dd5770dc996b3efab8647a5e3ee4a069593c679b..c50f8ceec0a634010a7f04dbb47f267be1d7074b 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -377,10 +377,10 @@ $(MARCH_OPTION) \ ifeq ($(BUILD_FOR_TEGRA),1) NVCC := $(JETPACK)/cuda/bin/nvcc - NVCCFLAGS := -x=cu -D__CUDACC__ -DNVCC -DNVIDIA_TEGRA -ccbin $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ --std c++11 --expt-relaxed-constexpr -m64 -gencode arch=compute_53,\"code=sm_53\" -gencode arch=compute_62,\"code=sm_62\" -DEIGEN_AVOID_STL_ARRAY -DTENSORFLOW_USE_EIGEN_THREADPOOL -DLANG_CXX11 -DEIGEN_HAS_C99_MATH -DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=5.3 + NVCCFLAGS := -x=cu -D__CUDACC__ -DNVCC -DANDROID_TEGRA -ccbin $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ --std c++11 --expt-relaxed-constexpr -m64 -gencode arch=compute_53,\"code=sm_53\" -gencode arch=compute_62,\"code=sm_62\" -DEIGEN_AVOID_STL_ARRAY -DTENSORFLOW_USE_EIGEN_THREADPOOL -DLANG_CXX11 -DEIGEN_HAS_C99_MATH -DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=5.3 CXXFLAGS4NVCC =\ -DIS_SLIM_BUILD \ --DNVIDIA_TEGRA \ +-DANDROID_TEGRA \ -fno-exceptions \ -DNDEBUG $(OPTFLAGS) \ -march=armv8-a \ @@ -391,7 +391,7 @@ $(MARCH_OPTION) \ CXXFLAGS +=\ -DGOOGLE_CUDA=1 \ -D__ANDROID_TYPES_FULL__ \ --DNVIDIA_TEGRA \ +-DANDROID_TEGRA \ -DEIGEN_AVOID_STL_ARRAY \ -DEIGEN_HAS_C99_MATH \ -DLANG_CXX11 -DTENSORFLOW_USE_EIGEN_THREADPOOL -DTF_EXTRA_CUDA_CAPABILITIES=5.3 diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD index d9d55faf50b7f5043bfd0ed3b3d9ca5c404c7627..23f90cf77ef0bde34f3938688aa6ca2f6e9bbc53 100644 --- a/tensorflow/contrib/mpi/BUILD +++ b/tensorflow/contrib/mpi/BUILD @@ -71,6 +71,8 @@ cc_library( "//tensorflow/core:protos_cc", "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:recent_request_ids", + "//tensorflow/core/distributed_runtime:request_id", "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_env", diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc index 1a2563d20fdc33d3c5e4a85561b61d04d3eeabff..8d14a3ef0404e727c47ad2ab39a69838fe1588aa 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc @@ -33,8 +33,10 @@ limitations under the License. namespace tensorflow { MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env) - : BaseRendezvousMgr(env), worker_env_2(env), use_optimal_transfer_(false) { - + : BaseRendezvousMgr(env), + worker_env_2(env), + use_optimal_transfer_(false), + recv_tensor_recent_request_ids_(100000) { const char* mpienv = getenv("MPI_OPTIMAL_PATH"); if (mpienv && mpienv[0] == '1') { LOG(INFO) << "MPI Optimal copy path enabled (Requires CUDA-Aware MPI when " @@ -149,6 +151,8 @@ MPIRemoteRendezvous::~MPIRemoteRendezvous() {} */ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request, const int mpi_dst) { + TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique( + req.request_id(), "RecvTensor (MPIRendezvousMgr)", req)); const int64 step_id = request.step_id(); const std::string& key = request.rendezvous_key(); Rendezvous::ParsedKey parsed; diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h index b15748d63c9fdbc5134069b63fd998e46c499e16..ca42ee2f6d246f67f5c4c668fe27b16722bc6130 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h @@ -30,10 +30,11 @@ limitations under the License. #include +#include "tensorflow/contrib/mpi/mpi_msg.pb.h" #include "tensorflow/contrib/mpi/mpi_utils.h" #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/contrib/mpi/mpi_msg.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" #define TAG_REQTENSOR 1010 @@ -104,6 +105,7 @@ class MPIRequestTensorCall { void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id) { req_.set_step_id(step_id); req_.set_rendezvous_key(parsed.FullKey().data(), parsed.FullKey().size()); + req_.set_request_id(GetUniqueRequestId()); request_buffer_size_ = req_.ByteSize(); // request_buffer_ = new char[request_buffer_size_]; // req_.SerializeToArray(request_buffer_, request_buffer_size_); @@ -177,6 +179,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr { std::map> recv_tensor_map_ GUARDED_BY(mrq_); + RecentRequestIds recv_tensor_recent_request_ids_; + void AddRequest(RecvTensorRequest, const int); void MPIBackgroundThread(); diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h index cb1719c3be6a5c042db6e258d68663e70bfbfa15..bb219e0edc8a2c4ba0ce0583cbe4018a4fa3a1d1 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.h +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ +#ifndef TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ +#define TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ #ifdef GOOGLE_CUDA @@ -136,4 +136,4 @@ class NcclManager { #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ +#endif // TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ diff --git a/tensorflow/contrib/ndlstm/python/lstm1d.py b/tensorflow/contrib/ndlstm/python/lstm1d.py index d3c3531f405a74d89ce736dae0134939e189f7ae..b24e332e4aea7f0ef981909558dcd6d730ca08a7 100644 --- a/tensorflow/contrib/ndlstm/python/lstm1d.py +++ b/tensorflow/contrib/ndlstm/python/lstm1d.py @@ -22,7 +22,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python.ops import variables from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn @@ -85,18 +84,11 @@ def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False): Output sequence (length, batch_size, noutput) """ with variable_scope.variable_scope(scope, "SeqLstm", [inputs]): - # TODO(tmb) make batch size, sequence_length dynamic - # example: sequence_length = tf.shape(inputs)[0] - _, batch_size, _ = _shape(inputs) - lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False) - state = array_ops.zeros([batch_size, lstm_cell.state_size]) - sequence_length = int(inputs.get_shape()[0]) - sequence_lengths = math_ops.to_int64( - array_ops.fill([batch_size], sequence_length)) + lstm_cell = rnn_cell.BasicLSTMCell(noutput) if reverse: inputs = array_ops.reverse_v2(inputs, [0]) outputs, _ = rnn.dynamic_rnn( - lstm_cell, inputs, sequence_lengths, state, time_major=True) + lstm_cell, inputs, time_major=True, dtype=inputs.dtype) if reverse: outputs = array_ops.reverse_v2(outputs, [0]) return outputs diff --git a/tensorflow/contrib/nearest_neighbor/kernels/heap.h b/tensorflow/contrib/nearest_neighbor/kernels/heap.h index 6e33a574e25d39a13a256383cbc9848fdb8b788f..32925569a82c43be75a0b6e93d7d781cda3d53f4 100644 --- a/tensorflow/contrib/nearest_neighbor/kernels/heap.h +++ b/tensorflow/contrib/nearest_neighbor/kernels/heap.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ +#ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ +#define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ #include #include @@ -205,4 +205,4 @@ class AugmentedHeap : public HeapBase { } // namespace nearest_neighbor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ +#endif // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_ diff --git a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h index 1670e2f83b3afa10ca76b765bf97cc1c08038fba..c53205e1a4089c8bb5159621662496b798acf242 100644 --- a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h +++ b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ +#ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ +#define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -232,4 +232,4 @@ class HyperplaneMultiprobe { } // namespace nearest_neighbor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ +#endif // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 9c961f2b9c828f7406516860b7e3fd3dc343d993..827279bd476f9666a972f43ad557fde6d0b6c59a 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -19,6 +19,7 @@ py_library( "python/training/elastic_average_optimizer.py", "python/training/external_optimizer.py", "python/training/lazy_adam_optimizer.py", + "python/training/model_average_optimizer.py", "python/training/moving_average_optimizer.py", "python/training/multitask_optimizer_wrapper.py", "python/training/nadam_optimizer.py", @@ -193,6 +194,27 @@ tf_py_test( ], ) +tf_py_test( + name = "model_average_optimizer_test", + srcs = ["python/training/model_average_optimizer_test.py"], + additional_deps = [ + ":opt_py", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:array_ops", + "//tensorflow/python:variables", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python:ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//third_party/py/numpy", + ], + tags = [ + "notap", # This test launches local server. + ], +) + py_test( name = "sign_decay_test", srcs = ["python/training/sign_decay_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 90d2f924629800ccf26c160edd22c13b817f4584..6c1bb1adc096f5b8e6945ea1492727d16cf29e65 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -29,6 +29,7 @@ from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -48,7 +49,9 @@ _allowed_symbols = [ 'MultitaskOptimizerWrapper', 'clip_gradients_by_global_norm', 'ElasticAverageOptimizer', - 'ElasticAverageCustomGetter' + 'ElasticAverageCustomGetter', + 'ModelAverageOptimizer', + 'ModelAverageCustomGetter' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py index 4c3fec067287e8edefcc4e36ca9fa91f5657013b..aeca900bc8ff4c4cc26da490ce43dfec70fd9f11 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py @@ -47,8 +47,9 @@ class LazyAdamOptimizer(adam.AdamOptimizer): """ def _apply_sparse(self, grad, var): - beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) - beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c97a1da2baf29914337094c6153447c997af08 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -0,0 +1,308 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper optimizer for Model Average.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import optimizer +from tensorflow.python.training import session_run_hook + +GLOBAL_VARIABLE_NAME = "global_center_variable" + + +class ModelAverageCustomGetter(object): + """Custom_getter class is used to do. + + 1. Change trainable variables to local collection and place them at worker + device + 2. Generate global variables + Notice that the class should be used with tf.replica_device_setter, + so that the global center variables and global step variable can be placed + at ps device. Besides, use 'tf.get_variable' instead of 'tf.Variable' to + use this custom getter. + + For example, + ma_custom_getter = ModelAverageCustomGetter(worker_device) + with tf.device( + tf.train.replica_device_setter( + worker_device=worker_device, + ps_device="/job:ps/cpu:0", + cluster=cluster)), + tf.variable_scope('',custom_getter=ma_custom_getter): + hid_w = tf.get_variable( + initializer=tf.truncated_normal( + [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], + stddev=1.0 / IMAGE_PIXELS), + name="hid_w") + hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]), + name="hid_b") + """ + + def __init__(self, worker_device): + """Create a new `ElasticAverageCustomGetter`. + + Args: + worker_device: String. Name of the `worker` job. + """ + self._worker_device = worker_device + self._local_2_global = {} + + def __call__(self, getter, name, trainable, collections, *args, **kwargs): + if trainable: + with ops.device(self._worker_device): + local_var = getter( + name, + trainable=True, + collections=[ops.GraphKeys.LOCAL_VARIABLES], + *args, + **kwargs) + + global_variable = variable_scope.variable( + name="%s/%s" % (GLOBAL_VARIABLE_NAME, name), + initial_value=local_var.initialized_value(), + trainable=False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + + self._local_2_global[local_var] = global_variable + return local_var + else: + return getter(name, trainable, collections, *args, **kwargs) + + +class ModelAverageOptimizer(optimizer.Optimizer): + """Wrapper optimizer that implements the Model Average algorithm. + + This is a sync optimizer. During the training, each worker will update + the local variables and maintains its own local_step, which starts from 0 + and is incremented by 1 after each update of local variables. Whenever the + interval_steps divides the local step, the local variables from all the + workers will be averaged and assigned to global center variables. Then the + local variables will be assigned by global center variables. + """ + + def __init__(self, + opt, + num_worker, + is_chief, + ma_custom_getter, + interval_steps=100, + use_locking=True, + name="ModelAverageOptimizer"): + """Construct a new model average optimizer. + + Args: + opt: The actual optimizer that will be used to update local variables + num_worker: The number of workers + is_chief: whether chief worker + ma_custom_getter: ModelAverageCustomGetter + interval_steps: An int point value to controls the frequency of the + average of local variables + use_locking: If True use locks for update operations + name: string. Optional name of the returned operation + """ + super(ModelAverageOptimizer, self).__init__(use_locking, name) + self._opt = opt + self._num_worker = num_worker + self._is_chief = is_chief + self._local_2_global = ma_custom_getter._local_2_global # pylint:disable=protected-access + self._interval_steps = interval_steps + self._accumulator_list = [] + self._chief_init_op = None + + self._local_step = variable_scope.get_variable( + initializer=0, + trainable=False, + collections=[ops.GraphKeys.LOCAL_VARIABLES], + name="local_step") + + self._opt._prepare() # pylint:disable=protected-access + + def compute_gradients(self, *args, **kwargs): + """Compute gradients of "loss" for the variables in "var_list". + + This simply wraps the compute_gradients() from the real optimizer. + + Args: + *args: Arguments for compute_gradients(). + **kwargs: Keyword arguments for compute_gradients(). + + Returns: + A list of (gradient, variable) pairs. + """ + return self._opt.compute_gradients(*args, **kwargs) + + def _local_vars_update(self, var_list): + """Get the update ops for the local variables in "var_list". + + Args: + var_list: Optional list or tuple of 'tf.Variable' to update + + Returns: + An update op + + Raises: + ValueError: if var_list is empty. + """ + if not var_list: + raise ValueError("The list of local_variables should not be empty") + update_ops = [] + global_center_vars = [self._local_2_global[var] for var in var_list] + for lvar, gvar in zip(var_list, global_center_vars): + with ops.device(lvar.device): + update_ops.append(state_ops.assign(lvar, gvar.read_value())) + return control_flow_ops.group(*(update_ops)) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """Apply gradients to variables. + + This contains most of the synchronization implementation and also wraps the + apply_gradients() from the real optimizer. The chief work updates global + variables. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + compute_gradients(). + global_step: Optional Variable to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the Optimizer constructor. + + Returns: + A conditional 'Operation' that update both local and global variables or + just local variables + + Raises: + ValueError: If the grads_and_vars is empty. + ValueError: If global step is not provided, the staleness cannot be + checked. + """ + + # update local variables + if not grads_and_vars: + raise ValueError("Must supply at least one variable") + if global_step is None: + raise ValueError("Global step is required") + + apply_updates = self._opt.apply_gradients(grads_and_vars) + with ops.control_dependencies([apply_updates]): + local_update = state_ops.assign_add( + self._local_step, 1, name="local_step_update").op + + # update global variables. + def _update_global_variables(): # pylint: disable=missing-docstring + local_vars = [v for g, v in grads_and_vars if g is not None] + global_vars = [self._local_2_global[v] for v in local_vars] + # sync queue + with ops.colocate_with(global_step): + sync_queue = data_flow_ops.FIFOQueue( + -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue") + train_ops = [] + aggregated_vars = [] + with ops.name_scope(None, self._name + "/global"): + for var, gvar in zip(local_vars, global_vars): + # pylint: disable=protected-access + with ops.device(gvar.device): + if isinstance(var._ref(), ops.Tensor): + var_accum = data_flow_ops.ConditionalAccumulator( + var.dtype, + shape=var.get_shape(), + shared_name=gvar.name + "/var_accum") + train_ops.append( + var_accum.apply_grad(var._ref(), local_step=global_step)) + aggregated_vars.append(var_accum.take_grad(self._num_worker)) + else: + raise ValueError("Unknown local variable type!") + self._accumulator_list.append((var_accum, gvar.device)) + # chief worker updates global vars and enqueues tokens to the sync queue + if self._is_chief: + update_ops = [] + with ops.control_dependencies(train_ops): + for avg_var, gvar in zip(aggregated_vars, global_vars): + with ops.device(gvar.device): + update_ops.append(state_ops.assign(gvar, avg_var)) + with ops.device(global_step.device): + update_ops.append(state_ops.assign_add(global_step, 1)) + with ops.control_dependencies(update_ops), ops.device( + global_step.device): + tokens = array_ops.fill([self._num_worker - 1], + constant_op.constant(False)) + sync_op = sync_queue.enqueue_many(tokens) + else: + with ops.control_dependencies(train_ops), ops.device( + global_step.device): + sync_op = sync_queue.dequeue() + + with ops.control_dependencies([sync_op]): + local_update_op = self._local_vars_update(local_vars) + return local_update_op + + with ops.control_dependencies([local_update]): + condition = math_ops.equal( + math_ops.mod(self._local_step, self._interval_steps), 0) + conditional_update = control_flow_ops.cond( + condition, _update_global_variables, control_flow_ops.no_op) + + chief_init_ops = [] + for accum, dev in self._accumulator_list: + with ops.device(dev): + chief_init_ops.append( + accum.set_global_step(global_step, name="SetGlobalStep")) + self._chief_init_op = control_flow_ops.group(*(chief_init_ops)) + + return conditional_update + + def get_init_op(self): + """Returns the op. + + This method lets all the local variables equal to the global + variables before the training begins. + """ + return self._local_vars_update(variables.trainable_variables()) + + def make_session_run_hook(self): + """Creates a hook to handle ModelAverage ops such as initialization.""" + return _ModelAverageOptimizerHook(self, self._is_chief) + + +class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook): # pylint: disable=missing-docstring + + def __init__(self, ma_optimizer, is_chief): + """Creates hook to handle ModelAverageOptimizer initialization ops. + + Args: + ma_optimizer: `ModelAverageOptimizer` which this hook will initialize. + is_chief: `Bool`, whether is this a chief replica or not. + """ + self._ma_optimizer = ma_optimizer + self._is_chief = is_chief + + def begin(self): + self._local_init_op = variables.local_variables_initializer() + self._global_init_op = None + if self._is_chief: + self._global_init_op = variables.global_variables_initializer() + self._chief_init_op = self._ma_optimizer._chief_init_op # pylint: disable=protected-access + self._variable_init_op = self._ma_optimizer.get_init_op() diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6cca0a8a009456f266245fd9a638bfab371c9b34 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py @@ -0,0 +1,198 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ModelAverageOptimizer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import portpicker + +from tensorflow.contrib.opt.python.training import model_average_optimizer +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import device_setter +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import server_lib +from tensorflow.python.training import training +from tensorflow.python.training import training_util + + +def create_local_cluster(num_workers, num_ps, protocol="grpc"): + """Create local GRPC servers and return them.""" + worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] + ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] + cluster_dict = { + "worker": ["localhost:%s" % port for port in worker_ports], + "ps": ["localhost:%s" % port for port in ps_ports] + } + cs = server_lib.ClusterSpec(cluster_dict) + + workers = [ + server_lib.Server( + cs, job_name="worker", protocol=protocol, task_index=ix, start=True) + for ix in range(num_workers) + ] + ps_servers = [ + server_lib.Server( + cs, job_name="ps", protocol=protocol, task_index=ix, start=True) + for ix in range(num_ps) + ] + + return cluster_dict, workers, ps_servers + + +# Creates the workers and return their sessions, graphs, train_ops. +# Cheif worker will update at last +def _get_workers(num_workers, steps, workers): + sessions = [] + graphs = [] + train_ops = [] + for worker_id in range(num_workers): + graph = ops.Graph() + is_chief = (worker_id == 0) + with graph.as_default(): + worker_device = "/job:worker/task:%d/cpu:0" % (worker_id) + ma_coustom = model_average_optimizer.ModelAverageCustomGetter( + worker_device=worker_device) + with variable_scope.variable_scope( + "", custom_getter=ma_coustom), ops.device( + device_setter.replica_device_setter( + worker_device=worker_device, + ps_device="/job:ps/task:0/cpu:0", + ps_tasks=1)): + + global_step = variables.Variable(0, name="global_step", trainable=False) + var_0 = variable_scope.get_variable(initializer=0.0, name="v0") + var_1 = variable_scope.get_variable(initializer=1.0, name="v1") + + with ops.device("/job:worker/task:" + str(worker_id)): + if worker_id == 0: + grads_0 = constant_op.constant(-1.0) + grads_1 = constant_op.constant(-1.0) + else: + grads_0 = constant_op.constant(-2.0) + grads_1 = constant_op.constant(-2.0) + sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) + opt = model_average_optimizer.ModelAverageOptimizer( + opt=sgd_opt, + num_worker=num_workers, + ma_custom_getter=ma_coustom, + is_chief=is_chief, + interval_steps=steps) + train_op = [ + opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]], + global_step) + ] + easgd_hook = opt.make_session_run_hook() + # Creates MonitoredSession + sess = training.MonitoredTrainingSession( + workers[worker_id].target, hooks=[easgd_hook]) + + sessions.append(sess) + graphs.append(graph) + train_ops.append(train_op) + return sessions, graphs, train_ops + + +class ModelAverageOptimizerTest(test.TestCase): + def _run(self, train_op, sess): + sess.run(train_op) + + def test1Workers2Period(self): + num_workers = 2 + steps = 2 + num_ps = 1 + _, workers, _ = create_local_cluster( + num_workers=num_workers, num_ps=num_ps) + + sessions, graphs, train_ops = _get_workers(num_workers, steps, workers) + + var_0 = graphs[0].get_tensor_by_name("v0:0") + var_1 = graphs[0].get_tensor_by_name("v1:0") + global_step = training_util.get_global_step(graphs[0]) + global_var_0 = graphs[0].get_tensor_by_name( + model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0") + global_var_1 = graphs[0].get_tensor_by_name( + model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0") + + # Verify the initialized value. + self.assertAllEqual(0.0, sessions[0].run(var_0)) + self.assertAllEqual(1.0, sessions[0].run(var_1)) + self.assertAllEqual(0.0, sessions[0].run(global_var_0)) + self.assertAllEqual(1.0, sessions[0].run(global_var_1)) + self.assertAllEqual(0, sessions[0].run(global_step)) + + sessions[0].run(train_ops[0]) + sessions[1].run(train_ops[1]) + + self.assertAllEqual(1.0, sessions[0].run(var_0)) + self.assertAllEqual(2.0, sessions[0].run(var_1)) + self.assertAllEqual(0.0, sessions[0].run(global_var_0)) + self.assertAllEqual(1.0, sessions[0].run(global_var_1)) + self.assertAllEqual(0, sessions[0].run(global_step)) + + # iteration 2, global varibale update + thread_0 = self.checkedThread( + target=self._run, args=(train_ops[0], sessions[0])) + thread_1 = self.checkedThread( + target=self._run, args=(train_ops[1], sessions[1])) + thread_0.start() + thread_1.start() + thread_0.join() + thread_1.join() + + self.assertAllEqual(3.0, sessions[0].run(var_0)) + self.assertAllEqual(4.0, sessions[0].run(var_1)) + self.assertAllEqual(3.0, sessions[0].run(global_var_0)) + self.assertAllEqual(4.0, sessions[0].run(global_var_1)) + self.assertAllEqual(1, sessions[0].run(global_step)) + + # iteration 3 + sessions[0].run(train_ops[0]) + + self.assertAllEqual(4.0, sessions[0].run(var_0)) + self.assertAllEqual(5.0, sessions[0].run(var_1)) + self.assertAllEqual(3.0, sessions[0].run(global_var_0)) + self.assertAllEqual(4.0, sessions[0].run(global_var_1)) + self.assertAllEqual(1, sessions[0].run(global_step)) + + def testPS2TasksWithClusterSpecClass(self): + cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + worker_device = "/job:worker/task:0" + ma_coustom = model_average_optimizer.ModelAverageCustomGetter( + worker_device=worker_device) + from tensorflow.python.training import device_setter + with ops.device( + device_setter.replica_device_setter(cluster=cluster_spec, + worker_device=worker_device, + ps_device="/job:ps")), \ + variable_scope.variable_scope("", custom_getter=ma_coustom): + v = variable_scope.get_variable(initializer=[1, 2], name="v") + w = variable_scope.get_variable(initializer=[2, 1], name="w") + v_g, w_g = ma_coustom._local_2_global[v], ma_coustom._local_2_global[w] + self.assertDeviceEqual("/job:worker/task:0", v.device) + self.assertDeviceEqual("job:ps/task:0", v_g.device) + self.assertDeviceEqual("/job:worker/task:0", w.device) + self.assertDeviceEqual("job:ps/task:1", w_g.device) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py index a4421ecfe6b0af9759c6aaa51d644f1211965b6a..44a8890cb107440b79cf8fbbdfcfda503b1c910f 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -34,12 +34,13 @@ class NadamOptimizer(adam.AdamOptimizer): def _apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.apply_adam( var, m, v, - math_ops.cast(self._beta1_power, var.dtype.base_dtype), - math_ops.cast(self._beta2_power, var.dtype.base_dtype), + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), math_ops.cast(self._lr_t, var.dtype.base_dtype), math_ops.cast(self._beta1_t, var.dtype.base_dtype), math_ops.cast(self._beta2_t, var.dtype.base_dtype), @@ -51,12 +52,13 @@ class NadamOptimizer(adam.AdamOptimizer): def _resource_apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.resource_apply_adam( var.handle, m.handle, v.handle, - math_ops.cast(self._beta1_power, grad.dtype.base_dtype), - math_ops.cast(self._beta2_power, grad.dtype.base_dtype), + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), math_ops.cast(self._lr_t, grad.dtype.base_dtype), math_ops.cast(self._beta1_t, grad.dtype.base_dtype), math_ops.cast(self._beta2_t, grad.dtype.base_dtype), @@ -66,8 +68,9 @@ class NadamOptimizer(adam.AdamOptimizer): use_nesterov=True) def _apply_sparse_shared(self, grad, var, indices, scatter_add): - beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) - beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 71582f9c9a01eb221666e2c71c4a2edb18e7cb98..bd9078ae76ee27ec26c09d1aa2012f871cbdf5e9 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -6,6 +6,7 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", + "py_test", "tf_gen_op_libs", "tf_custom_op_library", "tf_custom_op_py_library", @@ -64,11 +65,28 @@ py_library( "python/__init__.py", ], srcs_version = "PY2AND3", + tags = [ + "notap", + ], deps = [ ":periodic_resample_op_py", ], ) +py_test( + name = "periodic_resample_op_test", + srcs = ["python/kernel_tests/periodic_resample_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "notap", + ], + deps = [ + ":init_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_test_lib", + ], +) + # py_library( # name = "periodic_resample_op_py", # srcs = ["python/ops/periodic_resample_op.py"], diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h index bef21f7a5c8a27011f95eb7fae8451ca944d3cde..ba410f025d497178cfc1666ae231e75bad55b05e 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -100,6 +100,8 @@ template = input_tensor_shape.dim_size(i), + tensorflow::errors::InvalidArgument( + "periodic_resample expects the size of non-adjustable " + "dimensions be at least as large as size of input tensor." + " Dimension ", i, " input tensor has size ", + input_tensor_shape.dim_size(i), ", desired shape has size ", + desired_shape[i], ".")); + // target_dimensions[i] = desired_shape(i); target_dimensions[i] = desired_shape[i]; new_sliced_size *= target_dimensions[i]; diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc index c90fc06c7fb9d79e8fd7a937e786a34947d8c1cb..82bd79695646e3673c2c78ad99dd2bd200fc2fbf 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc @@ -34,26 +34,40 @@ This function implements a slightly more generic version of the subpixel convolutions found in this [paper](https://arxiv.org/abs/1609.05158). The formula for computing the elements in the `output` tensor is as follows: + `T` = `values` tensor of rank `R` + `S` = desired `shape` of output tensor (vector of length `R`) + `P` = `output` tensor of rank `R` - \((T_1,\ldots,T_R)\) = shape(`T`) - \([S_1,\ldots,S_q,\ldots,S_R]\) = elements of vector `S` - A single element in `S` is left unspecified (denoted \(S_q=-1\)). - Let \(f_i\) denote the (possibly non-integer) factor that relates the original - dimension to the desired dimensions, \(S_i=f_i T_i\), for \(i\neq q\) where - \(f_i>0\). + \\((T_1,\\ldots,T_R)\\) = shape(`T`) + + \\([S_1,\\ldots,S_q,\\ldots,S_R]\\) = elements of vector `S` + + A single element in `S` is left unspecified (denoted \\(S_q=-1\\)). + + Let \\(f_i\\) denote the (possibly non-integer) factor that relates the original + dimension to the desired dimensions, \\(S_i=f_i T_i\\), for \\(i\\neq q\\) where + \\(f_i>0\\). + Define the following: - \(g_i=\lceil f_i\rceil\) - \(t=\prod_i T_i\) - \(s=\prod_{i\neq q} S_i\) - \(S_q\) can then be defined as by \(S_q=\lfloor t/s\rfloor\). + + \\(g_i=\\lceil f_i\\rceil\\) + + \\(t=\\prod_i T_i\\) + + \\(s=\\prod_{i\\neq q} S_i\\) + + \\(S_q\\) can then be defined by \\(S_q=\\lfloor t/s\\rfloor\\). The elements of the resulting tensor are defined as - \(P_{s_1,\ldots,s_R}=T_{h_1,\ldots,h_q,\ldots,h_R}\). - The \(h_i\) (\(i\neq q\)) are defined by \(h_i=\lfloor s_i/g_i\rfloor\). - \(h_q=S_q\sum_{j\neq q}^{q-1}G_j \mathrm{mod}(s_j,g_j) + s_q\), where - \(G_j=\prod_{i}^{j-1}g_i\) (\(G_0=1\)). + + \\(P_{s_1,\\ldots,s_R}=T_{h_1,\\ldots,h_q,\\ldots,h_R}\\). + + The \\(h_i\\) (\\(i\\neq q\\)) are defined by \\(h_i=\\lfloor s_i/g_i\\rfloor\\). + + \\(h_q=S_q\\sum_{j\\neq q}^{q-1}G_j \\mathrm{mod}(s_j,g_j) + s_q\\), where + \\(G_j=\\prod_{i}^{j-1}g_i\\) (\\(G_0=1\\)). One drawback of this method is that whenever the output dimensions are slightly less than integer multiples of the input dimensions, many of the tensor elements diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py index 1d727870f652f3606218928983ea18e990d0afe6..a25de55e18b223db2b724aafb54b18d8f48a5baa 100644 --- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -19,8 +19,9 @@ from __future__ import division from __future__ import print_function import numpy -import tensorflow + from tensorflow.contrib.periodic_resample import periodic_resample +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -52,12 +53,11 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): def testPeriodicResampleBasic3D(self): - input_tensor = numpy.arange(2*2*4).reshape((2, 2, 4)) + input_tensor = numpy.arange(2 * 2 * 4).reshape((2, 2, 4)) desired_shape = numpy.array([4, 4, None]) - output_tensor = numpy.array([[[0], [2], [4], [6]], - [[1], [3], [5], [7]], - [[8], [10], [12], [14]], - [[9], [11], [13], [15]]]) + output_tensor = numpy.array([[[0], [2], [4], [6]], [[1], [3], [5], [7]], + [[8], [10], [12], [14]], [[9], [11], [13], + [15]]]) # NOTE: output_tensor != input_tensor.reshape((4, 4, -1)) with self.test_session(): @@ -71,24 +71,18 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): def testPeriodicResampleBasic4D(self): - input_tensor = numpy.arange(2*2*2*8).reshape((2, 2, 2, 8)) + input_tensor = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8)) desired_shape = numpy.array([4, 4, 4, None]) - output_tensor = numpy.array([[[[0], [4], [8], [12]], - [[2], [6], [10], [14]], - [[16], [20], [24], [28]], - [[18], [22], [26], [30]]], - [[[1], [5], [9], [13]], - [[3], [7], [11], [15]], - [[17], [21], [25], [29]], - [[19], [23], [27], [31]]], - [[[32], [36], [40], [44]], - [[34], [38], [42], [46]], - [[48], [52], [56], [60]], - [[50], [54], [58], [62]]], - [[[33], [37], [41], [45]], - [[35], [39], [43], [47]], - [[49], [53], [57], [61]], - [[51], [55], [59], [63]]]]) + output_tensor = numpy.array( + [[[[0], [4], [8], [12]], [[2], [6], [10], [14]], + [[16], [20], [24], [28]], [[18], [22], [26], [30]]], + [[[1], [5], [9], [13]], [[3], [7], [11], [15]], [[17], [21], [25], + [29]], + [[19], [23], [27], + [31]]], [[[32], [36], [40], [44]], [[34], [38], [42], [46]], + [[48], [52], [56], [60]], [[50], [54], [58], [62]]], + [[[33], [37], [41], [45]], [[35], [39], [43], [47]], + [[49], [53], [57], [61]], [[51], [55], [59], [63]]]]) # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1)) with self.test_session(): @@ -96,6 +90,19 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): result = periodic_resample(input_tensor, desired_shape).eval() self.assertAllEqual(result, output_tensor) + def testPeriodicResampleErrors(self): + input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) + with self.test_session(): + variables.global_variables_initializer().run() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + 'Dimension 3 input tensor has size 4, desired shape has size 1'): + periodic_resample(input_tensor, [None, 4, 4, 1]).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + '4, to be the same as the length of the desired shape, 3'): + periodic_resample(input_tensor, [None, 4, 4]).eval() + -if __name__ == "__main__": +if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/py2tf/BUILD index 7358822ef5ca7dba87cc1046001aa7f07f45f845..d395de986d2364f1f6567e1ecbf0a873cbb0aa8c 100644 --- a/tensorflow/contrib/py2tf/BUILD +++ b/tensorflow/contrib/py2tf/BUILD @@ -26,7 +26,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/convert", + "//tensorflow/contrib/py2tf/converters", "//tensorflow/contrib/py2tf/pyct", "//tensorflow/contrib/py2tf/pyct/static_analysis", "@gast_archive//:gast", @@ -46,7 +46,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/contrib/py2tf/convert", + "//tensorflow/contrib/py2tf/converters", "//tensorflow/contrib/py2tf/pyct", "//tensorflow/contrib/py2tf/pyct/static_analysis", "@gast_archive//:gast", diff --git a/tensorflow/contrib/py2tf/api.py b/tensorflow/contrib/py2tf/api.py index 3a367209694d3210913e515ece62ad1f9e3fc3ed..ca1f4e2645ee20fd78c0d837885823d2e199537a 100644 --- a/tensorflow/contrib/py2tf/api.py +++ b/tensorflow/contrib/py2tf/api.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from functools import wraps + import gast import six @@ -32,7 +34,115 @@ from tensorflow.python.util import tf_inspect # (currently we require (module + class name, type)) -def to_graph(o, arg_value_hints=None): +def graph_ready(f): + """No-op decorator that explicitly marks a function as graph-ready. + + Graph-ready functions are assumed to not need any conversion. + + Args: + f: Any callable. + Returns: + f itself. + """ + setattr(f, '__pyct_is_compile_decorator', True) + return f + + +def convert_inline(f, *args, **kwargs): + """Shorthand to convert and call a function. + + For example, the following two statements are equivalent: + + @convert() + def foo(): + ... + foo(bar) + + def foo(): + ... + convert_inline(foo, bar) + + Args: + f: Function to convert. Only this call will be converted. + *args: Passed through to f. + **kwargs: Passed through to f, with the following exceptions: + * arg_value_hints: A dict mapping parameter names to objects that can + hint at the type of those parameters. + + Returns: + The result of the converted f applied to args and kwargs. + """ + if 'arg_value_hints' in kwargs: + arg_value_hints = kwargs['arg_value_hints'] + del kwargs['arg_value_hints'] + else: + arg_value_hints = None + if tf_inspect.ismethod(f): + # When converting methods, the result is still an unbound function. + args = (f.__self__,) + args + return convert(arg_value_hints)(f)(*args, **kwargs) + + +def convert(recursive=False, arg_types=None): + """Decorator that compiles a function to graph mode. + + The decorator is dynamic - invoking compilation whenever the decorated function + is called. This means the parameter values are known at compilation. + + Args: + recursive: Whether to recusrively convert any functions that the decorator + function may call. + arg_types: See to_graph. + + Returns: + A decorator that compiles the given function to graph mode. + + Raises: + ValueError: If any of the arguments are illegal. + """ + if arg_types is None: + arg_types = {} + + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def wrapper(*args, **kwargs): + """Wrapper that calls the compiled version of the wrapped function.""" + partial_types = () + arg_values = {} + arg_names = tf_inspect.getargspec(f)[0] + for name, arg in zip(arg_names, args): + arg_values[name] = arg + arg_class = arg.__class__ + # If arg_value_hints specifies any name, use that instead. + if name not in arg_types: + arg_types[name] = (arg_class.__name__, arg_class) + if name == 'self' and tf_inspect.isclass(arg_class): + # Annotated methods need to specify that their owner type is partial, + # otherwise other members they call will not be converted. + partial_types = (arg_class,) + wrapped = to_graph( + f, + recursive=recursive, + arg_values=arg_values, + arg_types=arg_types, + partial_types=partial_types) + return wrapped(*args, **kwargs) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + +def to_graph(e, + recursive=True, + arg_values=None, + arg_types=None, + partial_types=None): """Compile a Python entity into equivalent TensorFlow code. Currently supported entities: @@ -42,16 +152,26 @@ def to_graph(o, arg_value_hints=None): Classes are handled by converting all their methods into a new class. Args: - o: A Python function or class. - arg_value_hints: A dict mapping parameter names to objects that can hint - at the type of those parameters. + e: A Python entity. + recursive: Whether to recusrively convert any functions that the decorator + function may call. + arg_values: A dict containing value hints for symbols like function + parameters. + arg_types: A dict containing type hints for symbols like function + parameters. + partial_types: A set of types (e.g. classes) that will not be converted + entirely. Calls to member functions for these types will be renamed + independently. Returns: A function with a signature identical to `o`, but which when executed it creates TF a graph that has the same functionality as the original entity. """ - conversion_map = conversion.ConversionMap() - _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints) + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, graph_ready, convert_inline), + partial_types=partial_types) + _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) module = gast.Module([]) for import_line in config.COMPILED_IMPORT_STATEMENTS: @@ -62,29 +182,39 @@ def to_graph(o, arg_value_hints=None): # The compiled code should see everything the entry function saw. # TODO(mdan): This might not work well if the call tree spans modules? - if tf_inspect.isfunction(o): - compiled_node.__dict__.update(six.get_function_globals(o)) + if tf_inspect.isfunction(e): + compiled_node.__dict__.update(six.get_function_globals(e)) compiled_fn = getattr(compiled_node, name) return compiled_fn -def to_code(o, arg_value_hints=None, indentation=' '): +def to_code(e, + recursive=True, + arg_values=None, + arg_types=None, + partial_types=None, + indentation=' '): """Return the equivalent of an entity in TensorFlow code. See `to_graph` for more details. Args: - o: A Python function or class. - arg_value_hints: A dict mapping parameter names to objects that can hint - at the type of those parameters. + e: A Python entity. + recursive: See to_graph. + arg_values: See to_graph. + arg_types: See to_graph. + partial_types: See to_graph. indentation: String, when to use for each level of indentation. Returns: String. """ - conversion_map = conversion.ConversionMap() - conversion.object_to_graph(o, conversion_map, arg_value_hints) + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, graph_ready, convert_inline), + partial_types=partial_types) + conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) code = '\n'.join( diff --git a/tensorflow/contrib/py2tf/api_test.py b/tensorflow/contrib/py2tf/api_test.py index 225b6d305fa5fe5a89cf0a639df84c2e29cda527..2384447708d7e0ab5dbfbeb592a47353f1909f50 100644 --- a/tensorflow/contrib/py2tf/api_test.py +++ b/tensorflow/contrib/py2tf/api_test.py @@ -28,17 +28,146 @@ from tensorflow.python.platform import test class ApiTest(test.TestCase): + def setUp(self): + config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) + config.COMPILED_IMPORT_STATEMENTS = ( + 'from tensorflow.python.ops ' + 'import control_flow_ops as tf',) + + def test_decorator_recurses(self): + + class TestClass(object): + + def called_member(self, a): + if a < 0: + a = -a + return a + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_decorator_does_not_recurse(self): + + class TestClass(object): + + def called_member(self, a): + return math_ops.negative(a) + + @api.convert(recursive=False) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_decorator_calls_converted(self): + + class TestClass(object): + + @api.graph_ready + def called_member(self, a): + return math_ops.negative(a) + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_decorator_calls_decorated(self): + + class TestClass(object): + + @api.convert() + def called_member(self, a): + if a < 0: + a = -a + return a + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= self.called_member(a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_convert_call_site_decorator(self): + + class TestClass(object): + + def called_member(self, a): + if a < 0: + a = -a + return a + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= api.convert_inline(self.called_member, a) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + + def test_graph_ready_call_site_decorator(self): + + class TestClass(object): + + def called_member(self, a): + return math_ops.negative(a) + + @api.convert(recursive=True) + def test_method(self, x, s, a): + while math_ops.reduce_sum(x) > s: + x //= api.graph_ready(self.called_member(a)) + return x + + tc = TestClass() + with self.test_session() as sess: + x = tc.test_method( + constant_op.constant([2, 4]), constant_op.constant(1), + constant_op.constant(-2)) + self.assertListEqual([0, 1], sess.run(x).tolist()) + def test_to_graph_basic(self): def test_fn(x, s): while math_ops.reduce_sum(x) > s: x //= 2 return x - config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) - config.COMPILED_IMPORT_STATEMENTS = ( - 'from tensorflow.python.ops ' - 'import control_flow_ops as tf', - ) compiled_fn = api.to_graph(test_fn) with self.test_session() as sess: @@ -51,7 +180,6 @@ class ApiTest(test.TestCase): x /= 2 return x - config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) compiled_code = api.to_code(test_fn) # Just check for some key words and that it is parseable Python code. diff --git a/tensorflow/contrib/py2tf/config.py b/tensorflow/contrib/py2tf/config.py index 0a9d52136eab494907992db0b6ad0cebcc1985ac..8c502a7a9e546dd9b9b40d7cf6d3c9821038afb3 100644 --- a/tensorflow/contrib/py2tf/config.py +++ b/tensorflow/contrib/py2tf/config.py @@ -22,6 +22,7 @@ PYTHON_LITERALS = { 'None': None, 'False': False, 'True': True, + 'float': float, } DEFAULT_UNCOMPILED_MODULES = set(( diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/conversion.py index 43bccae9538c4c68867764a9e433cac81bb98e78..b484eebbd58b955d1e783359269d16101d83cfd2 100644 --- a/tensorflow/contrib/py2tf/conversion.py +++ b/tensorflow/contrib/py2tf/conversion.py @@ -23,15 +23,17 @@ import six from tensorflow.contrib.py2tf import config from tensorflow.contrib.py2tf import naming -from tensorflow.contrib.py2tf.convert import break_canonicalization -from tensorflow.contrib.py2tf.convert import builtin_functions -from tensorflow.contrib.py2tf.convert import call_trees -from tensorflow.contrib.py2tf.convert import continue_canonicalization -from tensorflow.contrib.py2tf.convert import control_flow -from tensorflow.contrib.py2tf.convert import for_canonicalization -from tensorflow.contrib.py2tf.convert import logical_expressions -from tensorflow.contrib.py2tf.convert import print_functions -from tensorflow.contrib.py2tf.convert import side_effect_guards +from tensorflow.contrib.py2tf.converters import break_canonicalization +from tensorflow.contrib.py2tf.converters import builtin_functions +from tensorflow.contrib.py2tf.converters import call_trees +from tensorflow.contrib.py2tf.converters import continue_canonicalization +from tensorflow.contrib.py2tf.converters import control_flow +from tensorflow.contrib.py2tf.converters import decorators +from tensorflow.contrib.py2tf.converters import for_canonicalization +from tensorflow.contrib.py2tf.converters import logical_expressions +from tensorflow.contrib.py2tf.converters import print_functions +from tensorflow.contrib.py2tf.converters import side_effect_guards +from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct.static_analysis import access from tensorflow.contrib.py2tf.pyct.static_analysis import live_values @@ -39,22 +41,35 @@ from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.util import tf_inspect +# TODO(mdan): Might we not need any renaming at all? + + class ConversionMap(object): """ConversionMaps keep track of converting function hierarchies. Attributes: - dependency_cache: dict[object]: ast; maps original objects to their + recursive: Whether to recusrively convert any functions that the decorator + function may call. + nocompile_decorators: tuple of decorator functions that toggle compilation + off. + dependency_cache: dict[object]: ast; maps original entities to their converted AST - name_map: dict[string]: string; maps original objects to the name of + name_map: dict[string]: string; maps original entities to the name of their converted counterparts """ - def __init__(self): + # TODO(mdan): Rename to ConversionContext, and pull in additional flags. + + def __init__(self, recursive, nocompile_decorators, partial_types): + self.recursive = recursive + self.nocompile_decorators = nocompile_decorators + self.partial_types = partial_types if partial_types else () self.dependency_cache = {} self.name_map = {} - def new_namer(self, global_symbols): - return naming.Namer(global_symbols, self.name_map) + def new_namer(self, namespace): + return naming.Namer(namespace, self.recursive, self.name_map, + self.partial_types) def update_name_map(self, namer): for o, name in namer.renamed_calls.items(): @@ -62,77 +77,81 @@ class ConversionMap(object): if self.name_map[o] != name: raise ValueError( 'Calls to %s were converted using multiple names (%s). This is ' - 'possible when an object with one of these names already ' + 'possible when an entity with one of these names already ' 'existed. To fix, avoid using any of these names.') else: self.name_map[o] = name - def add_to_cache(self, original_object, converted_ast): - self.dependency_cache[original_object] = converted_ast + def add_to_cache(self, original_entity, converted_ast): + self.dependency_cache[original_entity] = converted_ast -def object_to_graph(o, conversion_map, value_hints): - """Compile a Python object into equivalent TensorFlow. +def entity_to_graph(o, conversion_map, arg_values, arg_types): + """Compile a Python entity into equivalent TensorFlow. - The function will also recursively compile all the objects that `o` + The function will also recursively compile all the entities that `o` references, updating `dependency_cache`. This function is reentrant, and relies on dependency_cache to avoid generating duplicate code. Args: - o: A Python object. + o: A Python entity. conversion_map: A ConversionMap object. - value_hints: A dict containing value hints for symbols like function + arg_values: A dict containing value hints for symbols like function + parameters. + arg_types: A dict containing type hints for symbols like function parameters. Returns: A tuple (ast, new_name): - * ast: An AST representing an object with interface equivalent to `o`, + * ast: An AST representing an entity with interface equivalent to `o`, but which when executed it creates TF a graph. - * new_name: The symbol name under which the new object can be found. + * new_name: The symbol name under which the new entity can be found. Raises: - ValueError: if the object is not supported. + ValueError: if the entity type is not supported. """ - if value_hints is None: - value_hints = {} - if tf_inspect.isclass(o): - node, new_name = class_to_graph(o, conversion_map, value_hints) + node, new_name = class_to_graph(o, conversion_map) elif tf_inspect.isfunction(o): - node, new_name = function_to_graph(o, conversion_map, value_hints) + node, new_name = function_to_graph(o, conversion_map, arg_values, arg_types) + elif tf_inspect.ismethod(o): + node, new_name = function_to_graph(o, conversion_map, arg_values, arg_types) else: raise ValueError( - 'Unsupported object type %s. Only functions and classes are supported' - ' for now.') + 'Entity "%s" has unsupported type "%s". Only functions and classes are ' + 'supported for now.' % (o, type(o))) conversion_map.add_to_cache(o, node) - # Recursively convert remaining dependencies. - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: - if hasattr(obj, 'im_class'): - # Class members are converted with their objects. - continue - object_to_graph(obj, conversion_map, None) + if conversion_map.recursive: + for obj in conversion_map.name_map.keys(): + if obj not in conversion_map.dependency_cache: + if (hasattr(obj, 'im_class') and + getattr(obj, 'im_class') not in conversion_map.partial_types): + # Class members are converted with their objects, unless they're + # only converted partially. + continue + entity_to_graph(obj, conversion_map, {}, {}) return node, new_name -def class_to_graph(c, conversion_map, param_value_hints): - """Specialization of `object_to_graph` for classes.""" +def class_to_graph(c, conversion_map): + """Specialization of `entity_to_graph` for classes.""" converted_members = {} members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod) if not members: raise ValueError('Cannot convert %s: it has no member methods.') - if 'self' in param_value_hints: - raise ValueError('Hints may not be provided for reserved name "self".') - param_value_hints['self'] = (c.__name__, c) - class_globals = None for _, m in members: - node, _ = function_to_graph(m, conversion_map, param_value_hints, c) + node, _ = function_to_graph( + m, + conversion_map=conversion_map, + arg_values={}, + arg_types={'self': (c.__name__, c)}, + owner_type=c) # TODO(mdan): Do not assume all members have the same view of globals. if class_globals is None: class_globals = six.get_function_globals(m) @@ -149,10 +168,11 @@ def class_to_graph(c, conversion_map, param_value_hints): return node, class_name -def function_to_graph(f, conversion_map, param_value_hints, owner_type=None): - """Specialization of `object_to_graph` for callable functions.""" +def function_to_graph(f, conversion_map, arg_values, arg_types, + owner_type=None): + """Specialization of `entity_to_graph` for callable functions.""" node = parser.parse_object(f).body[0] - node_globals = six.get_function_globals(f) + namespace = six.get_function_globals(f) # This is needed for non-global functions. closure = six.get_function_closure(f) @@ -160,10 +180,17 @@ def function_to_graph(f, conversion_map, param_value_hints, owner_type=None): for e in closure: if callable(e.cell_contents): fn = e.cell_contents - node_globals[fn.__name__] = fn - - namer = conversion_map.new_namer(node_globals) - node = node_to_graph(node, namer, node_globals, param_value_hints) + namespace[fn.__name__] = fn + + namer = conversion_map.new_namer(namespace) + ctx = context.EntityContext( + namer=namer, + source_code=tf_inspect.getsource(f), + source_file=tf_inspect.getfile(f), + namespace=namespace, + arg_values=arg_values, + arg_types=arg_types) + node = node_to_graph(node, ctx, conversion_map.nocompile_decorators) # Simulate a rename to ensure the top level is in the name map. This is needed # for top level functions, and it also helps the consistency verification made @@ -177,29 +204,30 @@ def function_to_graph(f, conversion_map, param_value_hints, owner_type=None): return node, conversion_map.name_map[f] -def _static_analysis_pass(node, namespace, value_hints): +def _static_analysis_pass(node, ctx): node = access.resolve(node) - node = live_values.resolve(node, namespace, config.PYTHON_LITERALS) - node = type_info.resolve(node, value_hints) + node = live_values.resolve(node, ctx.namespace, config.PYTHON_LITERALS) + node = type_info.resolve(node, ctx) return node -def node_to_graph(node, namer, namespace, value_hints): +def node_to_graph(node, ctx, nocompile_decorators): """Convert Python code to equivalent TF graph mode code. Args: node: A Python AST node representing the code to convert. - namer: A naming.Namer object. - namespace: Dict mapping symbol names to their corresponding live objects. - value_hints: A dict containing value hints for symbols like function - parameters. + ctx: An EntityContext object. + nocompile_decorators: A tuple containing decorators to be stripped from + functions during conversion. Returns: A tuple (node, deps): * node: A Python ast node, representing the converted code. - * deps: A set of strings, the fully qualified names of object + * deps: A set of strings, the fully qualified names of entity dependencies that this node has. """ + # TODO(mdan): Verify arguments for correctness. + # TODO(mdan): Factor out common elements. # These include: # * keeping track of symbols that have been created @@ -212,27 +240,30 @@ def node_to_graph(node, namer, namespace, value_hints): # tree, which must be accounted. Although less efficient, it is most robust # to re-run the analysis. - node = _static_analysis_pass(node, namespace, value_hints) - node = break_canonicalization.transform(node, namer) + node = _static_analysis_pass(node, ctx) + node = decorators.transform(node, nocompile_decorators) + node = break_canonicalization.transform(node, ctx.namer) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_canonicalization.transform(node, namer) - namespace['len'] = len + node = continue_canonicalization.transform(node, ctx.namer) + ctx.namespace['len'] = len - node = _static_analysis_pass(node, namespace, value_hints) - node = for_canonicalization.transform(node, namer) + node = _static_analysis_pass(node, ctx) + node = for_canonicalization.transform(node, ctx.namer) # for_canonicalization may insert new global references. node = builtin_functions.transform(node) # builtin_functions may insert new global references. - namespace['print'] = print + ctx.namespace['print'] = print - node = _static_analysis_pass(node, namespace, value_hints) + node = _static_analysis_pass(node, ctx) node = print_functions.transform(node) - node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES) - node = control_flow.transform(node, namer) + node = call_trees.transform(node, ctx.namer, ctx.namespace, + config.DEFAULT_UNCOMPILED_MODULES, + nocompile_decorators) + node = control_flow.transform(node, ctx.namer) node = logical_expressions.transform(node) - node = side_effect_guards.transform(node, namer) + node = side_effect_guards.transform(node, ctx.namer) return node diff --git a/tensorflow/contrib/py2tf/conversion_test.py b/tensorflow/contrib/py2tf/conversion_test.py index d76f14180951217810a3f5ddbca6423d8be63ce3..26f915f4f46e54c9648ae6b35415c4e2639af774 100644 --- a/tensorflow/contrib/py2tf/conversion_test.py +++ b/tensorflow/contrib/py2tf/conversion_test.py @@ -26,28 +26,31 @@ from tensorflow.python.platform import test class ConversionTest(test.TestCase): - def test_object_to_graph_unsupported_types(self): + def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion.object_to_graph('dummy', {}, {}) + conversion_map = conversion.ConversionMap(True, (), ()) + conversion.entity_to_graph('dummy', conversion_map, None, None) + + def test_entity_to_graph_callable(self): - def test_object_to_graph_callable(self): def f(a): return a - conversion_map = conversion.ConversionMap() - ast, new_name = conversion.object_to_graph(f, conversion_map, {}) + conversion_map = conversion.ConversionMap(True, (), ()) + ast, new_name = conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', new_name) - def test_object_to_graph_call_tree(self): + def test_entity_to_graph_call_tree(self): + def g(a): return a def f(a): return g(a) - conversion_map = conversion.ConversionMap() - conversion.object_to_graph(f, conversion_map, {}) + conversion_map = conversion.ConversionMap(True, (), ()) + conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(f in conversion_map.dependency_cache) self.assertTrue(g in conversion_map.dependency_cache) diff --git a/tensorflow/contrib/py2tf/convert/BUILD b/tensorflow/contrib/py2tf/converters/BUILD similarity index 79% rename from tensorflow/contrib/py2tf/convert/BUILD rename to tensorflow/contrib/py2tf/converters/BUILD index 0eb7998dc4c6acdc7760024b8e4359360b60c23e..2b0a1234e6934c8a0ee73316a2fb7bfdb991f7e9 100644 --- a/tensorflow/contrib/py2tf/convert/BUILD +++ b/tensorflow/contrib/py2tf/converters/BUILD @@ -15,13 +15,14 @@ filegroup( ) py_library( - name = "convert", + name = "converters", srcs = [ "break_canonicalization.py", "builtin_functions.py", "call_trees.py", "continue_canonicalization.py", "control_flow.py", + "decorators.py", "for_canonicalization.py", "logical_expressions.py", "print_functions.py", @@ -34,13 +35,26 @@ py_library( ], ) +py_library( + name = "test_lib", + srcs = [ + "converter_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":converters", + "//tensorflow/contrib/py2tf/pyct/static_analysis", + "@gast_archive//:gast", + ], +) + py_test( name = "break_canonicalization_test", srcs = ["break_canonicalization_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) @@ -49,9 +63,8 @@ py_test( name = "call_trees_test", srcs = ["call_trees_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) @@ -60,9 +73,8 @@ py_test( name = "continue_canonicalization_test", srcs = ["continue_canonicalization_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) @@ -71,9 +83,8 @@ py_test( name = "control_flow_test", srcs = ["control_flow_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) @@ -82,9 +93,8 @@ py_test( name = "builtin_functions_test", srcs = ["builtin_functions_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) @@ -93,9 +103,8 @@ py_test( name = "for_canonicalization_test", srcs = ["for_canonicalization_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) @@ -104,9 +113,8 @@ py_test( name = "logical_expressions_test", srcs = ["logical_expressions_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) @@ -115,9 +123,8 @@ py_test( name = "print_functions_test", srcs = ["print_functions_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", "@gast_archive//:gast", ], @@ -127,9 +134,8 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], deps = [ - ":convert", + ":test_lib", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/convert/__init__.py b/tensorflow/contrib/py2tf/converters/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/__init__.py rename to tensorflow/contrib/py2tf/converters/__init__.py diff --git a/tensorflow/contrib/py2tf/convert/break_canonicalization.py b/tensorflow/contrib/py2tf/converters/break_canonicalization.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/break_canonicalization.py rename to tensorflow/contrib/py2tf/converters/break_canonicalization.py diff --git a/tensorflow/contrib/py2tf/convert/break_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py similarity index 84% rename from tensorflow/contrib/py2tf/convert/break_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/break_canonicalization_test.py index 23c4c4d3e23e3e8eaafbafe9166d8c9618701fa5..b5ba2ad923dfeb73b38169494f6c7ea16ee815f1 100644 --- a/tensorflow/contrib/py2tf/convert/break_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import break_canonicalization -from tensorflow.contrib.py2tf.convert import control_flow +from tensorflow.contrib.py2tf.converters import break_canonicalization +from tensorflow.contrib.py2tf.converters import control_flow +from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access from tensorflow.python.platform import test @@ -32,12 +31,7 @@ class TestNamer(control_flow.SymbolNamer): return name_root -class BreakCanonicalizationTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - return node +class BreakCanonicalizationTest(converter_test_base.TestCase): def test_basic_break(self): @@ -50,7 +44,7 @@ class BreakCanonicalizationTest(test.TestCase): v.append(x) return v - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) node = break_canonicalization.transform(node, TestNamer()) result = compiler.ast_to_object(node) @@ -82,7 +76,7 @@ class BreakCanonicalizationTest(test.TestCase): v.append(x) return v - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) node = break_canonicalization.transform(node, TestNamer()) result = compiler.ast_to_object(node) @@ -110,7 +104,7 @@ class BreakCanonicalizationTest(test.TestCase): v.append(x) return v, u, w - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) node = break_canonicalization.transform(node, TestNamer()) result = compiler.ast_to_object(node) diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/builtin_functions.py rename to tensorflow/contrib/py2tf/converters/builtin_functions.py diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions_test.py b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py similarity index 68% rename from tensorflow/contrib/py2tf/convert/builtin_functions_test.py rename to tensorflow/contrib/py2tf/converters/builtin_functions_test.py index 633602f4d49792c45826afd8646593e280e35d12..b5358da6bc0be06ec1f59d0ef58d926289b5b78f 100644 --- a/tensorflow/contrib/py2tf/convert/builtin_functions_test.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py @@ -18,32 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import builtin_functions +from tensorflow.contrib.py2tf.converters import builtin_functions +from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class BuiltinFunctionsTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) - node = type_info.resolve(node, {}) - return node +class BuiltinFunctionsTest(converter_test_base.TestCase): def test_len(self): def test_fn(a): return len(a) - node = self._parse_and_analyze(test_fn, {'len': len}) + node = self.parse_and_analyze(test_fn, {'len': len}) node = builtin_functions.transform(node) result = compiler.ast_to_object(node) setattr(result, 'tf', array_ops) diff --git a/tensorflow/contrib/py2tf/convert/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py similarity index 59% rename from tensorflow/contrib/py2tf/convert/call_trees.py rename to tensorflow/contrib/py2tf/converters/call_trees.py index 92c3439101ed9d3fe54147346be3cd6a1c0f9d8c..df071f596fc31502a98182f27bb66c54f71d2572 100644 --- a/tensorflow/contrib/py2tf/convert/call_trees.py +++ b/tensorflow/contrib/py2tf/converters/call_trees.py @@ -27,6 +27,7 @@ import types import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct import templates @@ -64,16 +65,75 @@ class FunctionNamer(object): class CallTreeTransformer(gast.NodeTransformer): """Transforms the call tree by renaming transformed symbols.""" - def __init__(self, namer, uncompiled_modules): + def __init__(self, namer, namespace, uncompiled_modules, + nocompile_decorators): self.namer = namer + self.namespace = namespace self.uncompiled_modules = uncompiled_modules + self.nocompile_decorators = nocompile_decorators # pylint:disable=invalid-name - def _should_compile(self, fqn): + def _resolve_name(self, node): + if isinstance(node, gast.Call): + return self._resolve_name(node.func) + if isinstance(node, gast.Name): + return self.namespace.get(node.id) + if isinstance(node, gast.Attribute): + parent = self._resolve_name(node.value) + if parent is not None: + return getattr(parent, node.attr) + return None + raise ValueError(node) + + def _try_resolve_target(self, node): + """Works for methods of objects of known type.""" + if anno.hasanno(node, 'live_val'): + return anno.getanno(node, 'live_val') + if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'): + member = getattr(anno.getanno(node, 'type'), node.attr) + return member + return None + + def _should_compile(self, node, fqn): for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False + + # Check for local decorations + if anno.hasanno(node, 'graph_ready'): + return False + + # The decorators themselves are not to be converted. + # If present, the decorators should appear as static functions. + target_obj = self._try_resolve_target(node.func) + if target_obj is not None: + # This attribute is set by the decorator itself. + # TODO(mdan): This may not play nicely with other wrapping decorators. + if hasattr(target_obj, '__pyct_is_compile_decorator'): + return False + + if target_obj in self.nocompile_decorators: + return False + + # Inspect the target function decorators. If any include a @convert + # or @graph_ready annotation, then they must be called as they are. + # TODO(mdan): This may be quite heavy. + # To parse and re-analize each function for every call site could be quite + # wasteful. Maybe we could cache the parsed AST? + try: + target_node = parser.parse_object(target_obj).body[0] + except TypeError: + # Functions whose source we cannot access are compilable (e.g. wrapped + # to py_func). + return True + + for dec in target_node.decorator_list: + decorator_fn = self._resolve_name(dec) + if (decorator_fn is not None and + decorator_fn in self.nocompile_decorators): + return False + return True def _rename_compilable_function(self, node): @@ -82,15 +142,15 @@ class CallTreeTransformer(gast.NodeTransformer): target_obj = anno.getanno(node.func, 'live_val') target_fqn = anno.getanno(node.func, 'fqn') - if not self._should_compile(target_fqn): + if not self._should_compile(node, target_fqn): return node if anno.hasanno(node, 'is_constructor'): new_name = self.namer.compiled_class_name( - '.'.join(target_fqn), live_object=target_obj) + '__'.join(target_fqn), live_object=target_obj) else: new_name = self.namer.compiled_function_name( - '.'.join(target_fqn), live_object=target_obj) + '__'.join(target_fqn), live_object=target_obj) node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None) return node @@ -101,15 +161,24 @@ class CallTreeTransformer(gast.NodeTransformer): assert anno.hasanno(node.func, 'type') target_type = anno.getanno(node.func, 'type') - if not self._should_compile(type_fqn): + if not self._should_compile(node, type_fqn): return node # TODO(mdan): We should not assume that the namer only needs the # member function name. + method_name = node.func.attr + method_object = getattr(target_type, method_name) new_name = self.namer.compiled_function_name( - node.func.attr, live_object=None, owner_type=target_type) - node.func.attr = new_name - + method_name, live_object=method_object, owner_type=target_type) + if new_name != node.func.attr: + # If a member function call is renamed, then the new function is no + # longer bound to the target object. We then refactor the call from: + # foo.bar(...) + # to: + # renamed_foo(bar, ...) + # TODO(mdan): This risks causing duplication, if target_type is renamed. + node.args = [node.func.value] + node.args + node.func = gast.Name(new_name, gast.Load(), None) return node def _wrap_to_py_func_no_return(self, node): @@ -136,6 +205,7 @@ class CallTreeTransformer(gast.NodeTransformer): wrapper=gast.Name(wrapper_name, gast.Load(), None), args=args) anno.setanno(call_expr.value, 'args_scope', args_scope) + # TODO(mdan): Rename this annotation to 'graph_ready' anno.setanno(wrapper_def, 'skip_processing', True) return (wrapper_def, call_expr) @@ -151,7 +221,7 @@ class CallTreeTransformer(gast.NodeTransformer): if not self._function_is_compilable(target_obj): if anno.hasanno(node.value.func, 'fqn'): target_fqn = anno.getanno(node.value.func, 'fqn') - if not self._should_compile(target_fqn): + if not self._should_compile(node.value, target_fqn): return node node = self._wrap_to_py_func_no_return(node.value) return node @@ -163,6 +233,17 @@ class CallTreeTransformer(gast.NodeTransformer): return node def visit_Call(self, node): + # If the function is wrapped by one of the marker decorators, + # consider it graph ready. + if anno.hasanno(node.func, 'live_val'): + target_obj = anno.getanno(node.func, 'live_val') + if target_obj in self.nocompile_decorators: + if len(node.args) < 1: + raise ValueError( + 'Found call to decorator function "%s", but it had no arguments. ' + 'A decorator needs at least an argument.') + anno.setanno(node.args[0], 'graph_ready', True) + self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): target_obj = anno.getanno(node.func, 'live_val') @@ -180,20 +261,24 @@ class CallTreeTransformer(gast.NodeTransformer): # pylint:enable=invalid-name -def transform(node, namer, uncompiled_modules): +def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators): """Transform function call to the compiled counterparts. Args: node: AST to transform. namer: FunctionNamer-like. + namespace: Dict mapping symbol names to their corresponding live objects. uncompiled_modules: set of string tuples, each tuple represents the fully qualified name of a package containing functions that will not be compiled. + nocompile_decorators: A tuple containing decorators to be stripped from + functions during conversion. Returns: A tuple (node, new_names): node: The transformed AST new_names: set(string), containing any newly-generated names """ - transformer = CallTreeTransformer(namer, uncompiled_modules) + transformer = CallTreeTransformer(namer, namespace, uncompiled_modules, + nocompile_decorators) node = transformer.visit(node) return node diff --git a/tensorflow/contrib/py2tf/convert/call_trees_test.py b/tensorflow/contrib/py2tf/converters/call_trees_test.py similarity index 73% rename from tensorflow/contrib/py2tf/convert/call_trees_test.py rename to tensorflow/contrib/py2tf/converters/call_trees_test.py index 38c701eaadee8ad4df006a950192d51d78c799fe..8cb8d7be0f122ed124b0fda69c745a349543a16d 100644 --- a/tensorflow/contrib/py2tf/convert/call_trees_test.py +++ b/tensorflow/contrib/py2tf/converters/call_trees_test.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import call_trees +from tensorflow.contrib.py2tf.converters import call_trees +from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -35,14 +32,7 @@ class TestNamer(call_trees.FunctionNamer): return 'renamed_%s' % original_name -class CallTreesTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) - node = type_info.resolve(node, {}) - return node +class CallTreesTest(converter_test_base.TestCase): def test_basic(self): @@ -55,8 +45,8 @@ class CallTreesTest(test.TestCase): def test_fn_2(a): return test_fn_1(a) + 1 - node = self._parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, TestNamer(), set()) + node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) + node = call_trees.transform(node, TestNamer(), {}, (), ()) result = compiler.ast_to_object(node) # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually. setattr(result, 'renamed_test_fn_1', renamed_test_fn_1) @@ -70,13 +60,13 @@ class CallTreesTest(test.TestCase): a = math_ops.add(a, constant_op.constant(1)) return a - node = self._parse_and_analyze(test_fn, { + node = self.parse_and_analyze(test_fn, { 'math_ops': math_ops, 'constant_op': constant_op }) - node = call_trees.transform(node, TestNamer(), + node = call_trees.transform(node, TestNamer(), {}, set(((math_ops.__name__,), - (constant_op.__name__,)))) + (constant_op.__name__,))), ()) result = compiler.ast_to_object(node) setattr(result, 'math_ops', math_ops) setattr(result, 'constant_op', constant_op) diff --git a/tensorflow/contrib/py2tf/convert/continue_canonicalization.py b/tensorflow/contrib/py2tf/converters/continue_canonicalization.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/continue_canonicalization.py rename to tensorflow/contrib/py2tf/converters/continue_canonicalization.py diff --git a/tensorflow/contrib/py2tf/convert/continue_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py similarity index 83% rename from tensorflow/contrib/py2tf/convert/continue_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py index a041ff4641fef6c6d5cd7c502d1196dde26c55e0..c1fe903a2dd332626c8e64826652723c30ac412a 100644 --- a/tensorflow/contrib/py2tf/convert/continue_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import continue_canonicalization -from tensorflow.contrib.py2tf.convert import control_flow +from tensorflow.contrib.py2tf.converters import continue_canonicalization +from tensorflow.contrib.py2tf.converters import control_flow +from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access from tensorflow.python.platform import test @@ -32,12 +31,7 @@ class TestNamer(control_flow.SymbolNamer): return name_root -class ContinueCanonicalizationTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - return node +class ContinueCanonicalizationTest(converter_test_base.TestCase): def test_basic_continue(self): @@ -50,7 +44,7 @@ class ContinueCanonicalizationTest(test.TestCase): v.append(x) return v - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) node = continue_canonicalization.transform(node, TestNamer()) result = compiler.ast_to_object(node) @@ -71,7 +65,7 @@ class ContinueCanonicalizationTest(test.TestCase): v.append(x) return v - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) node = continue_canonicalization.transform(node, TestNamer()) result = compiler.ast_to_object(node) @@ -97,7 +91,7 @@ class ContinueCanonicalizationTest(test.TestCase): v.append(x) return v, u, w - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) node = continue_canonicalization.transform(node, TestNamer()) result = compiler.ast_to_object(node) diff --git a/tensorflow/contrib/py2tf/convert/control_flow.py b/tensorflow/contrib/py2tf/converters/control_flow.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/control_flow.py rename to tensorflow/contrib/py2tf/converters/control_flow.py diff --git a/tensorflow/contrib/py2tf/convert/control_flow_test.py b/tensorflow/contrib/py2tf/converters/control_flow_test.py similarity index 79% rename from tensorflow/contrib/py2tf/convert/control_flow_test.py rename to tensorflow/contrib/py2tf/converters/control_flow_test.py index 121af4ee949152cb6df7496a4a0c64f13f65a5eb..054e33750dbae86559a9575dfecde64132b9a2cd 100644 --- a/tensorflow/contrib/py2tf/convert/control_flow_test.py +++ b/tensorflow/contrib/py2tf/converters/control_flow_test.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import control_flow +from tensorflow.contrib.py2tf.converters import control_flow +from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test @@ -40,14 +37,7 @@ class TestNamer(control_flow.SymbolNamer): i += 1 -class ControlFlowTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) - node = type_info.resolve(node, {}) - return node +class ControlFlowTest(converter_test_base.TestCase): def test_simple_while(self): @@ -59,7 +49,7 @@ class ControlFlowTest(test.TestCase): i += 1 return s, i, n - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, TestNamer()) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) @@ -75,7 +65,7 @@ class ControlFlowTest(test.TestCase): n -= 1 return n - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, TestNamer()) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) @@ -94,7 +84,7 @@ class ControlFlowTest(test.TestCase): b = 2 * n return a, b - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, TestNamer()) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) @@ -112,7 +102,7 @@ class ControlFlowTest(test.TestCase): n = -n return n - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, TestNamer()) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/py2tf/converters/converter_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..ed006bad6d833b3682f819e87aa8b9c279372e51 --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/converter_test_base.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for tests in this module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.pyct import context +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct.static_analysis import live_values +from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.python.platform import test + + +class TestCase(test.TestCase): + + def parse_and_analyze(self, + test_fn, + namespace, + arg_types=None, + include_type_analysis=True): + ctx = context.EntityContext( + namer=None, + source_code=None, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types) + node = parser.parse_object(test_fn) + node = access.resolve(node) + node = live_values.resolve(node, namespace, {}) + if include_type_analysis: + node = type_info.resolve(node, ctx) + return node diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/py2tf/converters/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..a4313bfa510a81463a218cd21b41d9a7f43d1892 --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/decorators.py @@ -0,0 +1,56 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles decorators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import pretty_printer + + +class DecoratorsTransformer(gast.NodeTransformer): + """Converts or removes decorators.""" + + def __init__(self, remove_decorators): + self.remove_decorators = remove_decorators + + # pylint:disable=invalid-name + + def visit_FunctionDef(self, node): + self.generic_visit(node) + for dec in node.decorator_list: + if isinstance(dec, gast.Call): + dec = dec.func + if not anno.hasanno(dec, 'live_val'): + raise ValueError( + 'Could not resolve decorator: %s' % pretty_printer.fmt(dec)) + dec_value = anno.getanno(dec, 'live_val') + if dec_value in self.remove_decorators: + continue + raise ValueError('Dont know how to convert decorators for now.') + node.decorator_list = [] + return node + + # pylint:enable=invalid-name + + +def transform(node, remove_decorators): + transformer = DecoratorsTransformer(remove_decorators) + node = transformer.visit(node) + return node diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization.py b/tensorflow/contrib/py2tf/converters/for_canonicalization.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/for_canonicalization.py rename to tensorflow/contrib/py2tf/converters/for_canonicalization.py diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py similarity index 75% rename from tensorflow/contrib/py2tf/convert/for_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/for_canonicalization_test.py index 8de2d1a0f82cbb2f995a83fcdc1521ebf172e1ce..a6e6350fd45e9c9575af9c12d3d0c4e9b89bee41 100644 --- a/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import control_flow -from tensorflow.contrib.py2tf.convert import for_canonicalization +from tensorflow.contrib.py2tf.converters import control_flow +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import for_canonicalization from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access from tensorflow.python.platform import test @@ -32,12 +31,7 @@ class TestNamer(control_flow.SymbolNamer): return name_root -class ControlFlowTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - return node +class ControlFlowTest(converter_test_base.TestCase): def test_basic_for(self): @@ -47,7 +41,7 @@ class ControlFlowTest(test.TestCase): s += e return s - node = self._parse_and_analyze(test_fn, {}) + node = self.parse_and_analyze(test_fn, {}) node = for_canonicalization.transform(node, TestNamer()) result = compiler.ast_to_object(node) diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions.py b/tensorflow/contrib/py2tf/converters/logical_expressions.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/logical_expressions.py rename to tensorflow/contrib/py2tf/converters/logical_expressions.py diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions_test.py b/tensorflow/contrib/py2tf/converters/logical_expressions_test.py similarity index 85% rename from tensorflow/contrib/py2tf/convert/logical_expressions_test.py rename to tensorflow/contrib/py2tf/converters/logical_expressions_test.py index f07fa017b9dacd7a998f04fa7f6fdd83fccb1811..d711065099b24ad814104e6460e6ca551b31b3e6 100644 --- a/tensorflow/contrib/py2tf/convert/logical_expressions_test.py +++ b/tensorflow/contrib/py2tf/converters/logical_expressions_test.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import logical_expressions +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import logical_expressions from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class GradientsFunctionTest(test.TestCase): +class GradientsFunctionTest(converter_test_base.TestCase): def test_equals(self): def test_fn(a, b): return a == b - node = parser.parse_object(test_fn) + node = self.parse_and_analyze(test_fn, {}) node = logical_expressions.transform(node) result = compiler.ast_to_object(node) setattr(result, 'tf', math_ops) @@ -46,7 +46,7 @@ class GradientsFunctionTest(test.TestCase): def test_fn(a, b, c): return (a or b) and (a or b or c) - node = parser.parse_object(test_fn) + node = self.parse_and_analyze(test_fn, {}) node = logical_expressions.transform(node) result = compiler.ast_to_object(node) setattr(result, 'tf', math_ops) diff --git a/tensorflow/contrib/py2tf/convert/print_functions.py b/tensorflow/contrib/py2tf/converters/print_functions.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/print_functions.py rename to tensorflow/contrib/py2tf/converters/print_functions.py diff --git a/tensorflow/contrib/py2tf/convert/print_functions_test.py b/tensorflow/contrib/py2tf/converters/print_functions_test.py similarity index 65% rename from tensorflow/contrib/py2tf/convert/print_functions_test.py rename to tensorflow/contrib/py2tf/converters/print_functions_test.py index 65e592b66e9d0c08c7d2127ff40be8a0dc28ec6c..475196ce102955b350acf9bf94255997f875f62c 100644 --- a/tensorflow/contrib/py2tf/convert/print_functions_test.py +++ b/tensorflow/contrib/py2tf/converters/print_functions_test.py @@ -20,30 +20,20 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.convert import print_functions +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import print_functions from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.platform import test -class PrintFunctionsTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) - node = type_info.resolve(node, {}) - return node +class PrintFunctionsTest(converter_test_base.TestCase): def test_transform(self): def test_fn(a): print(a) - node = self._parse_and_analyze(test_fn, {'print': print}) + node = self.parse_and_analyze(test_fn, {'print': print}) node = print_functions.transform(node) result = compiler.ast_to_object(node) diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards.py b/tensorflow/contrib/py2tf/converters/side_effect_guards.py similarity index 100% rename from tensorflow/contrib/py2tf/convert/side_effect_guards.py rename to tensorflow/contrib/py2tf/converters/side_effect_guards.py diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards_test.py b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py similarity index 73% rename from tensorflow/contrib/py2tf/convert/side_effect_guards_test.py rename to tensorflow/contrib/py2tf/converters/side_effect_guards_test.py index d932840186034c073512cbd1e253fc7676aa83e7..5c56973dc2ae5d1976a68f040772e856cdaeabf5 100644 --- a/tensorflow/contrib/py2tf/convert/side_effect_guards_test.py +++ b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.convert import side_effect_guards +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import side_effect_guards from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops @@ -37,14 +34,7 @@ class TestNamer(side_effect_guards.SymbolNamer): return name_root -class SideEffectGuardsTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace): - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) - node = type_info.resolve(node, {}) - return node +class SideEffectGuardsTest(converter_test_base.TestCase): def test_transform(self): @@ -52,7 +42,7 @@ class SideEffectGuardsTest(test.TestCase): state_ops.assign(a, a + 1) return a - node = self._parse_and_analyze(test_fn, {'state_ops': state_ops}) + node = self.parse_and_analyze(test_fn, {'state_ops': state_ops}) node = side_effect_guards.transform(node, TestNamer()) result = compiler.ast_to_object(node) setattr(result, 'state_ops', state_ops) diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/naming.py index 61772ec07b41d366769307982bf0376de9bb495e..a90758962b83e1616f7d727440eb7481c49343ad 100644 --- a/tensorflow/contrib/py2tf/naming.py +++ b/tensorflow/contrib/py2tf/naming.py @@ -34,8 +34,10 @@ class Namer(object): * side_effect_guards.SymbolNamer """ - def __init__(self, global_namespace, name_map=None): + def __init__(self, global_namespace, recursive, name_map, partial_types): self.global_namespace = global_namespace + self.recursive = recursive + self.partial_types = partial_types self.renamed_calls = {} if name_map is not None: @@ -54,6 +56,7 @@ class Namer(object): while new_name in self.global_namespace: n += 1 new_name = '%s_%d' % (new_name_root, n) + if live_object is not None: self.renamed_calls[live_object] = new_name self.generated_names.add(new_name) @@ -67,7 +70,9 @@ class Namer(object): if live_object is not None and live_object in self.renamed_calls: return self.renamed_calls[live_object] - if owner_type is None: + if not self.recursive: + new_name = original_name + elif owner_type is None or owner_type in self.partial_types: # Top level functions: rename new_name_root = 'tf__%s' % original_name new_name = new_name_root diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/naming_test.py index 9403d9ae1f68d49ac19503b24fb86486cf197200..7bfc9b8733b6efc3ab440ae5a0614258ae395ad4 100644 --- a/tensorflow/contrib/py2tf/naming_test.py +++ b/tensorflow/contrib/py2tf/naming_test.py @@ -28,7 +28,7 @@ class NamerTest(test.TestCase): def bar(): pass - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('tf__foo', namer.compiled_function_name('foo')) self.assertEqual('tf__bar', namer.compiled_function_name('bar', bar)) self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls) @@ -38,7 +38,7 @@ class NamerTest(test.TestCase): def foo(): pass - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) @@ -46,22 +46,22 @@ class NamerTest(test.TestCase): def foo(): pass - namer = naming.Namer(set(('tf__foo',))) + namer = naming.Namer({'tf__foo': 1}, True, None, ()) self.assertEqual('tf__foo_1', namer.compiled_function_name('foo', foo)) def test_new_symbol_tracks_names(self): - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('temp', namer.new_symbol('temp', set())) self.assertItemsEqual(('temp',), namer.generated_names) def test_new_symbol_avoids_duplicates(self): - namer = naming.Namer(set()) + namer = naming.Namer({}, True, None, ()) self.assertEqual('temp', namer.new_symbol('temp', set())) self.assertEqual('temp_1', namer.new_symbol('temp', set())) self.assertItemsEqual(('temp', 'temp_1'), namer.generated_names) def test_new_symbol_avoids_conflicts(self): - namer = naming.Namer(set(('temp',))) + namer = naming.Namer({'temp': 1}, True, None, ()) # temp is reserved in the global namespace self.assertEqual('temp_1', namer.new_symbol('temp', set())) # temp_2 is reserved in the local namespace diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD index b60ed918f5185e963de0877b13c3747d2b86f1e4..e0331dbc97c688ed34be362426b0b1f0d25931bc 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/py2tf/pyct/BUILD @@ -20,9 +20,11 @@ py_library( "__init__.py", "anno.py", "compiler.py", + "context.py", "parser.py", "pretty_printer.py", "templates.py", + "transformer.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/py2tf/pyct/context.py new file mode 100644 index 0000000000000000000000000000000000000000..73f3613d09d01e9e643cfb8ee3a8e67e5c126455 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/context.py @@ -0,0 +1,42 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conversion context containers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class EntityContext(object): + """Contains information about an entity, like source code. + + Attributes: + namer: Namer that matches the contract of all converters. + source_code: The entity's source code. + source_file: The entity's source file. + namespace: Dict[str->*], containing symbols visible to the entity + (excluding parameters). + arg_values: Dict[str->*], containing parameter values, if known. + arg_types: Dict[str->*], containing parameter types, if known. + """ + + def __init__(self, namer, source_code, source_file, namespace, arg_values, + arg_types): + self.namer = namer + self.source_code = source_code + self.source_file = source_file + self.namespace = namespace + self.arg_values = {} if arg_values is None else arg_values + self.arg_types = {} if arg_types is None else arg_types diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py index 3e545903261a41cac4dc9ac0e23f857e0be41f96..0042aa90ed218d42aedc720c94d1a478bc9f18f5 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py @@ -24,6 +24,7 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import transformer from tensorflow.python.util import tf_inspect @@ -69,7 +70,7 @@ class Scope(object): raise KeyError(name) -class TypeInfoResolver(gast.NodeTransformer): +class TypeInfoResolver(transformer.Base): """Annotates symbols with type information where possible. Nodes currently annotated: @@ -77,9 +78,9 @@ class TypeInfoResolver(gast.NodeTransformer): * Attribute (helps resolve object methods) """ - def __init__(self, value_hints): + def __init__(self, context): + super(TypeInfoResolver, self).__init__(context) self.scope = Scope(None) - self.value_hints = value_hints self.function_level = 0 def visit_FunctionDef(self, node): @@ -120,13 +121,11 @@ class TypeInfoResolver(gast.NodeTransformer): self.generic_visit(node) if isinstance(node.ctx, gast.Param): self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None)) - # TODO(mdan): Member functions should not need type hints. - # We could attemp to extract im_class from the live_val annotation. - if self.function_level == 1 and node.id in self.value_hints: + if self.function_level == 1 and node.id in self.context.arg_types: # Forge a node to hold the type information, so that method calls on # it can resolve the type. type_holder = gast.Name(node.id, gast.Load(), None) - type_string, type_obj = self.value_hints[node.id] + type_string, type_obj = self.context.arg_types[node.id] anno.setanno(type_holder, 'type', type_obj) anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) self.scope.setval(node.id, type_holder) @@ -206,6 +205,5 @@ class TypeInfoResolver(gast.NodeTransformer): return node -def resolve(node, value_hints): - assert value_hints is not None - return TypeInfoResolver(value_hints).visit(node) +def resolve(node, context): + return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py index 8526f42413b9cca077da45195249615b55c45bc9..a491f49ca3b87d1340fdd691431e127737abc006 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py @@ -19,7 +19,9 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import transformer from tensorflow.contrib.py2tf.pyct.static_analysis import access from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info @@ -54,17 +56,27 @@ class ScopeTest(test.TestCase): class TypeInfoResolverTest(test.TestCase): + def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + ctx = context.EntityContext( + namer=None, + source_code=None, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types) + node = parser.parse_object(test_fn) + node = access.resolve(node) + node = live_values.resolve(node, namespace, {}) + node = type_info.resolve(node, ctx) + return node + def test_constructor_detection(self): def test_fn(): opt = training.GradientDescentOptimizer(0.1) return opt - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'training': training}, {}) - node = type_info.resolve(node, {}) - + node = self._parse_and_analyze(test_fn, {'training': training}) call_node = node.body[0].body[0].value self.assertEquals(training.GradientDescentOptimizer, anno.getanno(call_node, 'type')) @@ -77,11 +89,7 @@ class TypeInfoResolverTest(test.TestCase): opt = training.GradientDescentOptimizer(0.1) opt.minimize(0) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'training': training}, {}) - node = type_info.resolve(node, {}) - + node = self._parse_and_analyze(test_fn, {'training': training}) attr_call_node = node.body[0].body[1].value.func self.assertEquals((training.__name__, 'GradientDescentOptimizer'), anno.getanno(attr_call_node, 'type_fqn')) @@ -92,11 +100,7 @@ class TypeInfoResolverTest(test.TestCase): with session.Session() as sess: sess.run(x) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'session': session}, {}) - node = type_info.resolve(node, {}) - + node = self._parse_and_analyze(test_fn, {'session': session}) constructor_call = node.body[0].body[0].items[0].context_expr self.assertEquals(session.Session, anno.getanno(constructor_call, 'type')) self.assertEquals((session.__name__, 'Session'), @@ -115,33 +119,25 @@ class TypeInfoResolverTest(test.TestCase): opt = training.GradientDescentOptimizer(0.01) opt.minimize(0) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'training': training}, {}) - with self.assertRaises(ValueError): - node = type_info.resolve(node, {}) + with self.assertRaises(transformer.PyFlowParseError): + self._parse_and_analyze(test_fn, {'training': training}) def test_parameter_class_members(self): def test_fn(opt): opt.minimize(0) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'training': training}, {}) - with self.assertRaises(ValueError): - node = type_info.resolve(node, {}) + with self.assertRaises(transformer.PyFlowParseError): + self._parse_and_analyze(test_fn, {'training': training}) def test_parameter_class_members_with_value_hints(self): def test_fn(opt): opt.minimize(0) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'training': training}, {}) - node = type_info.resolve( - node, { + node = self._parse_and_analyze( + test_fn, {'training': training}, + arg_types={ 'opt': (('%s.GradientDescentOptimizer' % training.__name__), training.GradientDescentOptimizer(0.1)) }) @@ -160,11 +156,8 @@ class TypeInfoResolverTest(test.TestCase): foo = bar foo() - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'bar': bar}, {}) - with self.assertRaises(ValueError): - node = type_info.resolve(node, {}) + with self.assertRaises(transformer.PyFlowParseError): + self._parse_and_analyze(test_fn, {'bar': bar}) def test_nested_members(self): @@ -172,11 +165,8 @@ class TypeInfoResolverTest(test.TestCase): foo = training.GradientDescentOptimizer(0.1) foo.bar.baz() - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'training': training}, {}) - with self.assertRaises(ValueError): - node = type_info.resolve(node, {}) + with self.assertRaises(transformer.PyFlowParseError): + self._parse_and_analyze(test_fn, {'training': training}) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/pyct/transformer.py b/tensorflow/contrib/py2tf/pyct/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d5aa23eaebbbf7540d52d9fa9cc5292e0f756e6d --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/transformer.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A node transformer that includes utilities for SCT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import pretty_printer + + +class PyFlowParseError(SyntaxError): + pass + + +class Base(gast.NodeTransformer): + """Base class for specialized transformers.""" + + def __init__(self, context): + """Initialize the transformer. Subclasses should call this. + + Args: + context: An EntityContext. + """ + self._lineno = 0 + self._col_offset = 0 + self.context = context + + def visit(self, node): + try: + source_code = self.context.source_code + source_file = self.context.source_file + if source_code and hasattr(node, 'lineno'): + self._lineno = node.lineno + self._col_offset = node.col_offset + return super(Base, self).visit(node) + except ValueError as e: + msg = '%s\nOccurred at node:\n%s' % (str(e), pretty_printer.fmt(node)) + if source_code: + line = self._source.splitlines()[self._lineno - 1] + else: + line = '' + raise PyFlowParseError( + msg, (source_file, self._lineno, self._col_offset + 1, line)) diff --git a/tensorflow/contrib/quantize/__init__.py b/tensorflow/contrib/quantize/__init__.py index 5d4e4575c935e0a888c6e5e4d0db640d93e1bd49..933200e60749e62094040672793953c1c79de6cf 100644 --- a/tensorflow/contrib/quantize/__init__.py +++ b/tensorflow/contrib/quantize/__init__.py @@ -27,6 +27,8 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "create_eval_graph", "create_training_graph", + "experimental_create_eval_graph", + "experimental_create_training_graph", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index d647bb94e849c713c2aca93c53f372bae5857c43..bbd9743d8014ce495a4967e7484981f7e60ae4a3 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -128,3 +128,67 @@ def create_eval_graph(input_graph, elements=None, device_name_or_function=None): is_training=False, elements=elements, device_name_or_function=device_name_or_function) + + +def experimental_create_training_graph(input_graph, + elements=None, + device_name_or_function=None): + """Returns a transformed training input_graph for simulated quantization. + + This function has additional experimental options not (yet) available to + create_training_graph. The resulting behavior may be undefined. + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. + + Returns: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements, if elements is not None. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + return _create_graph( + input_graph=input_graph, + is_training=True, + elements=elements, + device_name_or_function=device_name_or_function) + + +def experimental_create_eval_graph(input_graph, + elements=None, + device_name_or_function=None): + """Returns a transformed eval input_graph for simulated quantization. + + This function has additional experimental options not (yet) available to + create_eval_graph. The resulting behavior may be undefined. + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. + + Returns: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements, if elements is not None. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + return _create_graph( + input_graph=input_graph, + is_training=False, + elements=elements, + device_name_or_function=device_name_or_function) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index 3407ace3914fe2de2506a2952ea5d1bf19028bb9..514862a0ab5b796718a04aa65a46e7a7e3b86330 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -31,28 +31,30 @@ from tensorflow.python.platform import googletest class QuantizeGraphTest(test_util.TensorFlowTestCase): - # We have a lot of other tests that test the details of the rewrite, here we # just the specific features of the quantize_graph API. - def testReturnedElementsTraining(self): - self._TestReturnElements(True) - def testReturnedElementsEval(self): - self._TestReturnElements(False) + def _RunTestOverParameters(self, test_fn): + rewrite_fns = [ + quantize_graph.create_training_graph, + quantize_graph.create_eval_graph, + quantize_graph.experimental_create_training_graph, + quantize_graph.experimental_create_eval_graph, + ] + for fn in rewrite_fns: + test_fn(fn) + + def testReturnedElements(self): + self._RunTestOverParameters(self._TestReturnElements) - def _TestReturnElements(self, is_training): + def _TestReturnElements(self, fn): graph = ops.Graph() with graph.as_default(): a = constant_op.constant(1.0) b = variables.Variable(2.0) c = a + b elements = [a, b, c.op] - if is_training: - q_graph, returned_elements = quantize_graph.create_training_graph( - graph, elements=elements) - else: - q_graph, returned_elements = quantize_graph.create_eval_graph( - graph, elements=elements) + q_graph, returned_elements = fn(graph, elements=elements) # Make sure q_graph is different from graph. self.assertTrue(graph != q_graph) # Check that the returned elements are part of the new graph. @@ -62,35 +64,26 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): for element, returned_element in zip(elements, returned_elements): self.assertEqual(element.name, returned_element.name) - def testNoReturnElementsTraining(self): - self._TestNoReturnElements(True) + def testNoReturnElements(self): + self._RunTestOverParameters(self._TestNoReturnElements) - def testNoReturnElementsEval(self): - self._TestNoReturnElements(False) - - def _TestNoReturnElements(self, is_training): + def _TestNoReturnElements(self, fn): graph = ops.Graph() with graph.as_default(): a = constant_op.constant(1.0) b = variables.Variable(2.0) _ = a + b - if is_training: - q_graph = quantize_graph.create_training_graph(graph) - else: - q_graph = quantize_graph.create_eval_graph(graph) + q_graph = fn(graph) # Check that quantize_graph didn't return a tuple when elements isn't # provided. self.assertTrue(isinstance(q_graph, ops.Graph)) # Make sure q_graph is different from graph. self.assertTrue(graph != q_graph) - def testDeviceNameTraining(self): - self._TestDeviceName(True) - - def testDeviceNameEval(self): - self._TestDeviceName(False) + def testDeviceName(self): + self._RunTestOverParameters(self._TestDeviceName) - def _TestDeviceName(self, is_training): + def _TestDeviceName(self, fn): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -106,12 +99,7 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): _ = nn_ops.relu6(conv) device_name = '/job:oink/task:0/device:CPU:0' - if is_training: - q_graph = quantize_graph.create_training_graph( - graph, device_name_or_function=device_name) - else: - q_graph = quantize_graph.create_eval_graph( - graph, device_name_or_function=device_name) + q_graph = fn(graph, device_name_or_function=device_name) orig_variable_names = set( [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py index b2360fec6ca2afd23233041cdd0d3fcadb4a460b..0388079f20dee0a6b249d568e2c51d1407d7466f 100644 --- a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py +++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py @@ -61,7 +61,7 @@ def _compute_output_resolution(input_spatial_resolution, kernel_size, stride, stride: Stride (int). total_padding: Total padding to be applied (int). Returns: - output_resolution: Ouput dimension (int) or None. + output_resolution: Output dimension (int) or None. """ if (input_spatial_resolution is None) or (kernel_size is None) or ( stride is None) or (total_padding is None): diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h index fc3a2da9b398b16df223d60e2e913f952fa24434..9bb1724a2c0b70ee7ce7238cc179aded95935b26 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -81,4 +81,4 @@ CALL_ALL_REDUCEOPS(ReduceSliceFunctorReduceop) } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ +#endif // TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.h b/tensorflow/contrib/resampler/kernels/resampler_ops.h index 8258ecaf5d3ba67094194c5cb12ca6d4d6efc85f..85d3676efac70fe9237d31c2be1fe75e67d70abd 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops.h +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_ +#define TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_ #if PLATFORM_WINDOWS #define __restrict__ __restrict @@ -64,5 +64,4 @@ struct ResamplerGrad2DFunctor{ } // namespace functor } // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_ +#endif // TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_ diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h index e33eceadff17fc3811f98fc29b3cb916b6a79766..a52c934233af3dc63e1a60d70fac6a9eba6a655b 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.h +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_ +#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_ +#define TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" @@ -74,4 +74,4 @@ struct TensorBlasGemm { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_ +#endif // TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_ diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.h b/tensorflow/contrib/rnn/kernels/gru_ops.h index 06a566506296dd658a01bb3038407f77a32cde84..3e2cb39e64bb3f0b22ea66c5601af36c5fb9b0fd 100644 --- a/tensorflow/contrib/rnn/kernels/gru_ops.h +++ b/tensorflow/contrib/rnn/kernels/gru_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ +#define TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/contrib/rnn/kernels/blas_gemm.h" @@ -181,4 +181,4 @@ struct GRUBlockCellBprop : public GRUCell { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ +#endif // TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index 1906581b16b2e76243320bc67c8ac831323fb8e7..bc6b85f3f1ab80b5ef5b4a8ba2e5242cf451adbe 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ +#define TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/contrib/rnn/kernels/blas_gemm.h" @@ -291,4 +291,4 @@ struct BlockLSTMBprop : public LSTMBlockCell { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ +#endif // TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index b5d81b7caac5186b34548a06c67ba48afab0a1a5..cafeb56ad88ba83fb42faf16db8ee1035da1deac 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -663,6 +663,12 @@ class DropoutWrapperTest(test.TestCase): self.assertEqual(res[1].h.shape, (batch_size, 3)) return res + def testWrappedCellProperty(self): + cell = rnn_cell_impl.BasicRNNCell(10) + wrapper = rnn_cell_impl.DropoutWrapper(cell) + # Github issue 15810 + self.assertEqual(wrapper.wrapped_cell, cell) + def testDropoutWrapperKeepAllConstantInput(self): keep = array_ops.ones([]) res = self._testDropoutWrapper( diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 73789206f3120c34b686a8af98f37d7683bc88ae..8a3894ef9d7042e66b52edefdf08b278dcc6c4f4 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -53,14 +53,12 @@ class RNNCellTest(test.TestCase): batch_size = 3 input_size = 4 expected_output = np.array( - [[0.121753, 0.121753], - [0.103349, 0.103349], - [0.100178, 0.100178]], + [[0.121753, 0.121753], [0.103349, 0.103349], [0.100178, 0.100178]], dtype=np.float32) expected_state = np.array( - [[0.137523, 0.137523, 0.121753, 0.121753], - [0.105450, 0.105450, 0.103349, 0.103349], - [0.100742, 0.100742, 0.100178, 0.100178]], + [[0.137523, 0.137523, 0.121753, 0.121753], [ + 0.105450, 0.105450, 0.103349, 0.103349 + ], [0.100742, 0.100742, 0.100178, 0.100178]], dtype=np.float32) with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): @@ -69,14 +67,14 @@ class RNNCellTest(test.TestCase): output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell( num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m) sess.run([variables.global_variables_initializer()]) - res = sess.run([output, state], { - x.name: - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]]), - m.name: - 0.1 * np.ones((batch_size, state_size)) - }) + res = sess.run( + [output, state], { + x.name: + np.array([[1., 1., 1., 1.], [2., 2., 2., 2.], + [3., 3., 3., 3.]]), + m.name: + 0.1 * np.ones((batch_size, state_size)) + }) # This is a smoke test: Only making sure expected values didn't change. self.assertEqual(len(res), 2) self.assertAllClose(res[0], expected_output) @@ -101,14 +99,14 @@ class RNNCellTest(test.TestCase): frequency_skip=frequency_skip, forget_bias=1.0)(x, m) sess.run([variables.global_variables_initializer()]) - res = sess.run([output, state], { - x.name: - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]]), - m.name: - 0.1 * np.ones((batch_size, int(state_size * (num_shifts)))) - }) + res = sess.run( + [output, state], { + x.name: + np.array([[1., 1., 1., 1.], [2., 2., 2., 2.], + [3., 3., 3., 3.]]), + m.name: + 0.1 * np.ones((batch_size, int(state_size * (num_shifts)))) + }) self.assertEqual(len(res), 2) # The numbers in results were not calculated, this is mostly just a # smoke test. @@ -141,17 +139,14 @@ class RNNCellTest(test.TestCase): state_is_tuple=True) inputs = constant_op.constant( np.array( - [[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]], + [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]], dtype=np.float32), dtype=dtypes.float32) state_value = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) - init_state = cell.state_tuple_type( - *([state_value, state_value] * num_shifts)) + init_state = cell.state_tuple_type(*( + [state_value, state_value] * num_shifts)) output, state = cell(inputs, init_state) sess.run([variables.global_variables_initializer()]) res = sess.run([output, state]) @@ -198,11 +193,10 @@ class RNNCellTest(test.TestCase): dtype=np.float32), dtype=dtypes.float32) state_value = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) - init_state = cell.state_tuple_type( - *([state_value, state_value] * total_blocks)) + init_state = cell.state_tuple_type(*( + [state_value, state_value] * total_blocks)) output, state = cell(inputs, init_state) sess.run([variables.global_variables_initializer()]) res = sess.run([output, state]) @@ -230,20 +224,28 @@ class RNNCellTest(test.TestCase): frequency_skip = 1 num_shifts = int((input_size - feature_size) / frequency_skip + 1) expected_output = np.array( - [[0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020, - 0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699], - [0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342, - 0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171], - [0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533, - 0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250]], + [[ + 0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020, + 0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699 + ], [ + 0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342, + 0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171 + ], [ + 0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533, + 0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250 + ]], dtype=np.float32) expected_state = np.array( - [[0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134, - 0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865], - [0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432, - 0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245], - [0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522, - 0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390]], + [[ + 0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134, + 0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865 + ], [ + 0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432, + 0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245 + ], [ + 0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522, + 0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390 + ]], dtype=np.float32) for state_is_tuple in [False, True]: with self.test_session() as sess: @@ -259,18 +261,16 @@ class RNNCellTest(test.TestCase): couple_input_forget_gates=True, state_is_tuple=state_is_tuple) inputs = constant_op.constant( - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]], - dtype=np.float32), + np.array( + [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]], + dtype=np.float32), dtype=dtypes.float32) if state_is_tuple: state_value = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) - init_state = cell.state_tuple_type( - *([state_value, state_value] * num_shifts)) + init_state = cell.state_tuple_type(*( + [state_value, state_value] * num_shifts)) else: init_state = constant_op.constant( 0.1 * np.ones( @@ -302,32 +302,40 @@ class RNNCellTest(test.TestCase): frequency_skip = 1 num_shifts = int((input_size - feature_size) / frequency_skip + 1) expected_output = np.array( - [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283, - 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774, - 0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341, - 0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218], - [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057, - 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359, - 0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517, - 0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840], - [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357, - 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604, - 0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552, - 0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731]], + [[ + 0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283, + 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774, + 0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341, + 0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218 + ], [ + 0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057, + 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359, + 0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517, + 0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840 + ], [ + 0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357, + 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604, + 0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552, + 0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731 + ]], dtype=np.float32) expected_state = np.array( - [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293, - 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638, - 0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836, - 0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773], - [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811, - 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559, - 0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288, - 0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984], - [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919, - 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530, - 1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101, - 0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035]], + [[ + 0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293, + 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638, + 0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836, + 0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773 + ], [ + 0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811, + 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559, + 0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288, + 0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984 + ], [ + 1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919, + 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530, + 1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101, + 0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035 + ]], dtype=np.float32) with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): @@ -339,17 +347,16 @@ class RNNCellTest(test.TestCase): forget_bias=1.0, num_frequency_blocks=[num_shifts]) inputs = constant_op.constant( - np.array([[1.0, 1.1, 1.2, 1.3], - [2.0, 2.1, 2.2, 2.3], - [3.0, 3.1, 3.2, 3.3]], - dtype=np.float32), + np.array( + [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3], + [3.0, 3.1, 3.2, 3.3]], + dtype=np.float32), dtype=dtypes.float32) state_value = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) - init_state = cell.state_tuple_type( - *([state_value, state_value] * num_shifts * 2)) + init_state = cell.state_tuple_type(*( + [state_value, state_value] * num_shifts * 2)) output, state = cell(inputs, init_state) sess.run([variables.global_variables_initializer()]) res = sess.run([output, state]) @@ -375,32 +382,40 @@ class RNNCellTest(test.TestCase): frequency_skip = 1 num_shifts = int((input_size - feature_size) / frequency_skip + 1) expected_output = np.array( - [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283, - 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774, - 0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654, - 0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071], - [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057, - 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359, - 0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828, - 0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958], - [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357, - 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604, - 0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345, - 0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858]], + [[ + 0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283, + 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774, + 0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654, + 0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071 + ], [ + 0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057, + 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359, + 0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828, + 0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958 + ], [ + 0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357, + 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604, + 0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345, + 0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858 + ]], dtype=np.float32) expected_state = np.array( - [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293, - 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638, - 0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628, - 0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446], - [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811, - 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559, - 0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488, - 0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429], - [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919, - 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530, - 0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978, - 0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597]], + [[ + 0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293, + 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638, + 0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628, + 0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446 + ], [ + 0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811, + 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559, + 0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488, + 0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429 + ], [ + 1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919, + 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530, + 0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978, + 0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597 + ]], dtype=np.float32) with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): @@ -413,17 +428,16 @@ class RNNCellTest(test.TestCase): num_frequency_blocks=[num_shifts], backward_slice_offset=1) inputs = constant_op.constant( - np.array([[1.0, 1.1, 1.2, 1.3], - [2.0, 2.1, 2.2, 2.3], - [3.0, 3.1, 3.2, 3.3]], - dtype=np.float32), + np.array( + [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3], + [3.0, 3.1, 3.2, 3.3]], + dtype=np.float32), dtype=dtypes.float32) state_value = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) - init_state = cell.state_tuple_type( - *([state_value, state_value] * num_shifts * 2)) + init_state = cell.state_tuple_type(*( + [state_value, state_value] * num_shifts * 2)) output, state = cell(inputs, init_state) sess.run([variables.global_variables_initializer()]) res = sess.run([output, state]) @@ -474,8 +488,8 @@ class RNNCellTest(test.TestCase): for state_is_tuple in [False, True]: with ops.Graph().as_default(): with self.test_session() as sess: - with variable_scope.variable_scope("state_is_tuple_" + str( - state_is_tuple)): + with variable_scope.variable_scope( + "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( num_units, state_is_tuple=state_is_tuple) cell = contrib_rnn_cell.AttentionCellWrapper( @@ -525,16 +539,15 @@ class RNNCellTest(test.TestCase): for state_is_tuple in [False, True]: with ops.Graph().as_default(): with self.test_session() as sess: - with variable_scope.variable_scope("state_is_tuple_" + str( - state_is_tuple)): + with variable_scope.variable_scope( + "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( num_units, state_is_tuple=state_is_tuple) cell = contrib_rnn_cell.AttentionCellWrapper( lstm_cell, attn_length, state_is_tuple=state_is_tuple) if state_is_tuple: zeros = constant_op.constant( - 0.1 * np.ones( - [batch_size, num_units], dtype=np.float32), + 0.1 * np.ones([batch_size, num_units], dtype=np.float32), dtype=dtypes.float32) attn_state_zeros = constant_op.constant( 0.1 * np.ones( @@ -579,22 +592,25 @@ class RNNCellTest(test.TestCase): [1.018088, 0.378983, -0.572179, 0.268591]], dtype=np.float32) expected_state = np.array( - [[0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962, - 0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077, - 0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536, - 0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063, - 0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228, - 0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432, - 0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152, - 0.51843399], - [0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637, - 0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857, - 0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689, - 0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957, - 0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421, - 0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971, - 0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457, - 0.70582712]], + [[ + 0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962, + 0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077, + 0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536, + 0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063, + 0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228, + 0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432, + 0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152, + 0.51843399 + ], [ + 0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637, + 0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857, + 0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689, + 0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957, + 0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421, + 0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971, + 0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457, + 0.70582712 + ]], dtype=np.float32) seed = 12345 random_seed.set_random_seed(seed) @@ -602,7 +618,8 @@ class RNNCellTest(test.TestCase): for state_is_tuple in [False, True]: with session.Session() as sess: with variable_scope.variable_scope( - "state_is_tuple", reuse=state_is_tuple, + "state_is_tuple", + reuse=state_is_tuple, initializer=init_ops.glorot_uniform_initializer()): lstm_cell = rnn_cell.BasicLSTMCell( num_units, state_is_tuple=state_is_tuple) @@ -646,36 +663,31 @@ class RNNCellTest(test.TestCase): def testNASCell(self): num_units = 6 batch_size = 3 - expected_output = np.array([[0.576751, 0.576751, 0.576751, 0.576751, - 0.576751, 0.576751], - [0.618936, 0.618936, 0.618936, 0.618936, - 0.618936, 0.618936], - [0.627393, 0.627393, 0.627393, 0.627393, - 0.627393, 0.627393]]) - expected_state = np.array([[0.71579772, 0.71579772, 0.71579772, 0.71579772, - 0.71579772, 0.71579772, 0.57675087, 0.57675087, - 0.57675087, 0.57675087, 0.57675087, 0.57675087], - [0.78041625, 0.78041625, 0.78041625, 0.78041625, - 0.78041625, 0.78041625, 0.6189357, 0.6189357, - 0.61893570, 0.6189357, 0.6189357, 0.6189357], - [0.79457647, 0.79457647, 0.79457647, 0.79457647, - 0.79457653, 0.79457653, 0.62739348, 0.62739348, - 0.62739348, 0.62739348, 0.62739348, 0.62739348] - ]) + expected_output = np.array( + [[0.576751, 0.576751, 0.576751, 0.576751, 0.576751, 0.576751], + [0.618936, 0.618936, 0.618936, 0.618936, 0.618936, 0.618936], + [0.627393, 0.627393, 0.627393, 0.627393, 0.627393, 0.627393]]) + expected_state = np.array([[ + 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772, + 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087 + ], [ + 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625, + 0.6189357, 0.6189357, 0.61893570, 0.6189357, 0.6189357, 0.6189357 + ], [ + 0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653, + 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348 + ]]) with self.test_session() as sess: with variable_scope.variable_scope( - "nas_test", - initializer=init_ops.constant_initializer(0.5)): + "nas_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units) inputs = constant_op.constant( - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]], - dtype=np.float32), + np.array( + [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]], + dtype=np.float32), dtype=dtypes.float32) state_value = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) init_state = rnn_cell.LSTMStateTuple(state_value, state_value) output, state = cell(inputs, init_state) @@ -699,39 +711,34 @@ class RNNCellTest(test.TestCase): num_units = 6 batch_size = 3 num_proj = 5 - expected_output = np.array([[1.697418, 1.697418, 1.697418, 1.697418, - 1.697418], - [1.840037, 1.840037, 1.840037, 1.840037, - 1.840037], - [1.873985, 1.873985, 1.873985, 1.873985, - 1.873985]]) - expected_state = np.array([[0.69855207, 0.69855207, 0.69855207, 0.69855207, - 0.69855207, 0.69855207, 1.69741797, 1.69741797, - 1.69741797, 1.69741797, 1.69741797], - [0.77073824, 0.77073824, 0.77073824, 0.77073824, - 0.77073824, 0.77073824, 1.84003687, 1.84003687, - 1.84003687, 1.84003687, 1.84003687], - [0.78973997, 0.78973997, 0.78973997, 0.78973997, - 0.78973997, 0.78973997, 1.87398517, 1.87398517, - 1.87398517, 1.87398517, 1.87398517]]) + expected_output = np.array( + [[1.697418, 1.697418, 1.697418, 1.697418, + 1.697418], [1.840037, 1.840037, 1.840037, 1.840037, 1.840037], + [1.873985, 1.873985, 1.873985, 1.873985, 1.873985]]) + expected_state = np.array([[ + 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207, + 1.69741797, 1.69741797, 1.69741797, 1.69741797, 1.69741797 + ], [ + 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824, + 1.84003687, 1.84003687, 1.84003687, 1.84003687, 1.84003687 + ], [ + 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, + 1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517 + ]]) with self.test_session() as sess: with variable_scope.variable_scope( - "nas_proj_test", - initializer=init_ops.constant_initializer(0.5)): + "nas_proj_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj) inputs = constant_op.constant( - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]], - dtype=np.float32), + np.array( + [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]], + dtype=np.float32), dtype=dtypes.float32) state_value_c = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) state_value_h = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_proj), dtype=np.float32), + 0.1 * np.ones((batch_size, num_proj), dtype=np.float32), dtype=dtypes.float32) init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h) output, state = cell(inputs, init_state) @@ -755,24 +762,20 @@ class RNNCellTest(test.TestCase): num_units = 2 batch_size = 3 expected_state_and_output = np.array( - [[0.13752282, 0.13752282], - [0.10545051, 0.10545051], + [[0.13752282, 0.13752282], [0.10545051, 0.10545051], [0.10074195, 0.10074195]], dtype=np.float32) with self.test_session() as sess: with variable_scope.variable_scope( - "ugrnn_cell_test", - initializer=init_ops.constant_initializer(0.5)): + "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.UGRNNCell(num_units=num_units) inputs = constant_op.constant( - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]], - dtype=np.float32), + np.array( + [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]], + dtype=np.float32), dtype=dtypes.float32) init_state = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) output, state = cell(inputs, init_state) sess.run([variables.global_variables_initializer()]) @@ -786,13 +789,11 @@ class RNNCellTest(test.TestCase): num_units = 2 batch_size = 3 expected_state = np.array( - [[0.13752282, 0.13752282], - [0.10545051, 0.10545051], + [[0.13752282, 0.13752282], [0.10545051, 0.10545051], [0.10074195, 0.10074195]], dtype=np.float32) expected_output = np.array( - [[2.00431061, 2.00431061], - [4.00060606, 4.00060606], + [[2.00431061, 2.00431061], [4.00060606, 4.00060606], [6.00008249, 6.00008249]], dtype=np.float32) with self.test_session() as sess: @@ -802,14 +803,12 @@ class RNNCellTest(test.TestCase): cell = contrib_rnn_cell.IntersectionRNNCell( num_units=num_units, num_in_proj=num_units) inputs = constant_op.constant( - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]], - dtype=np.float32), + np.array( + [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]], + dtype=np.float32), dtype=dtypes.float32) init_state = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) output, state = cell(inputs, init_state) sess.run([variables.global_variables_initializer()]) @@ -824,19 +823,17 @@ class RNNCellTest(test.TestCase): batch_size = 3 cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units) inputs = constant_op.constant( - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]], - dtype=np.float32), + np.array( + [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]], + dtype=np.float32), dtype=dtypes.float32) init_state = constant_op.constant( - 0.1 * np.ones( - (batch_size, num_units), dtype=np.float32), + 0.1 * np.ones((batch_size, num_units), dtype=np.float32), dtype=dtypes.float32) - with self.assertRaisesRegexp( - ValueError, "Must have input size == output size for " - "Intersection RNN. To fix, num_in_proj should " - "be set to num_units at cell init."): + with self.assertRaisesRegexp(ValueError, + "Must have input size == output size for " + "Intersection RNN. To fix, num_in_proj should " + "be set to num_units at cell init."): cell(inputs, init_state) def testPhasedLSTMCell(self): @@ -845,13 +842,11 @@ class RNNCellTest(test.TestCase): batch_size = 3 input_size = 4 expected_state_c = np.array( - [[6.450831e-04, 4.697885e-04], - [9.862894e-05, 7.212213e-04], + [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04], [4.401947e-04, 9.143004e-04]], dtype=np.float32) expected_state_h = np.array( - [[4.621217e-04, 3.365449e-04], - [7.438179e-05, 5.439147e-04], + [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04], [3.347936e-04, 6.953785e-04]], dtype=np.float32) with variable_scope.variable_scope( @@ -864,14 +859,14 @@ class RNNCellTest(test.TestCase): output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)( (t, x), state0) sess.run([variables.global_variables_initializer()]) - res = sess.run([output, state], { - t.name: - np.array([[1.], [2.], [3.]]), - x.name: - np.array([[1., 1., 1., 1.], - [2., 2., 2., 2.], - [3., 3., 3., 3.]]), - }) + res = sess.run( + [output, state], { + t.name: + np.array([[1.], [2.], [3.]]), + x.name: + np.array([[1., 1., 1., 1.], [2., 2., 2., 2.], + [3., 3., 3., 3.]]), + }) # This is a smoke test, making sure expected values are unchanged. self.assertEqual(len(res), 2) self.assertAllClose(res[0], res[1].h) @@ -880,36 +875,32 @@ class RNNCellTest(test.TestCase): def testConv1DLSTMCell(self): with self.test_session() as sess: - shape = [2,1] + shape = [2, 1] filter_size = [3] num_features = 1 batch_size = 2 expected_state_c = np.array( - [[[1.4375670191], [1.4375670191]], - [[2.7542609292], [2.7542609292]]], + [[[1.4375670191], [1.4375670191]], [[2.7542609292], [2.7542609292]]], dtype=np.float32) expected_state_h = np.array( - [[[0.6529865603], [0.6529865603]], - [[0.8736877431], [0.8736877431]]], + [[[0.6529865603], [0.6529865603]], [[0.8736877431], [0.8736877431]]], dtype=np.float32) with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(1.0/2.0)): + "root", initializer=init_ops.constant_initializer(1.0 / 2.0)): x = array_ops.placeholder(dtypes.float32, [None, None, 1]) - cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape, - kernel_shape=filter_size, - output_channels=num_features) + cell = contrib_rnn_cell.Conv1DLSTMCell( + input_shape=shape, + kernel_shape=filter_size, + output_channels=num_features) hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32) output, state = cell(x, hidden) sess.run([variables.global_variables_initializer()]) - res = sess.run([output, state], { - hidden[0].name: - np.array([[[1.],[1.]], - [[2.],[2.]]]), - x.name: - np.array([[[1.],[1.]], - [[2.],[2.]]]), - }) + res = sess.run( + [output, state], { + hidden[0].name: np.array([[[1.], [1.]], [[2.], [2.]]]), + x.name: np.array([[[1.], [1.]], [[2.], [2.]]]), + }) # This is a smoke test, making sure expected values are unchanged. self.assertEqual(len(res), 2) self.assertAllClose(res[0], res[1].h) @@ -918,44 +909,40 @@ class RNNCellTest(test.TestCase): def testConv2DLSTMCell(self): with self.test_session() as sess: - shape = [2,2,1] - filter_size = [3,3] + shape = [2, 2, 1] + filter_size = [3, 3] num_features = 1 batch_size = 2 expected_state_c = np.array( - [[[[1.4375670191], [1.4375670191]], - [[1.4375670191], [1.4375670191]]], - [[[2.7542609292], [2.7542609292]], - [[2.7542609292], [2.7542609292]]]], + [[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]], + [[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]] + ]], dtype=np.float32) expected_state_h = np.array( - [[[[0.6529865603], [0.6529865603]], - [[0.6529865603], [0.6529865603]]], - [[[0.8736877431], [0.8736877431]], - [[0.8736877431], [0.8736877431]]]], + [[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]], + [[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]] + ]], dtype=np.float32) with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(1.0/4.0)): + "root", initializer=init_ops.constant_initializer(1.0 / 4.0)): x = array_ops.placeholder(dtypes.float32, [None, None, None, 1]) - cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape, - kernel_shape=filter_size, - output_channels=num_features) + cell = contrib_rnn_cell.Conv2DLSTMCell( + input_shape=shape, + kernel_shape=filter_size, + output_channels=num_features) hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32) output, state = cell(x, hidden) sess.run([variables.global_variables_initializer()]) - res = sess.run([output, state], { - hidden[0].name: - np.array([[[[1.],[1.]], - [[1.],[1.]]], - [[[2.],[2.]], - [[2.],[2.]]]]), - x.name: - np.array([[[[1.],[1.]], - [[1.],[1.]]], - [[[2.],[2.]], - [[2.],[2.]]]]), - }) + res = sess.run( + [output, state], { + hidden[0].name: + np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]], + [[2.], [2.]]]]), + x.name: + np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]], + [[2.], [2.]]]]), + }) # This is a smoke test, making sure expected values are unchanged. self.assertEqual(len(res), 2) self.assertAllClose(res[0], res[1].h) @@ -964,36 +951,33 @@ class RNNCellTest(test.TestCase): def testConv3DLSTMCell(self): with self.test_session() as sess: - shape = [2,2,2,1] - filter_size = [3,3,3] + shape = [2, 2, 2, 1] + filter_size = [3, 3, 3] num_features = 1 batch_size = 2 expected_state_c = np.array( - [[[[[1.4375670191], [1.4375670191]], - [[1.4375670191], [1.4375670191]]], - [[[1.4375670191], [1.4375670191]], - [[1.4375670191], [1.4375670191]]]], - [[[[2.7542609292], [2.7542609292]], - [[2.7542609292], [2.7542609292]]], - [[[2.7542609292], [2.7542609292]], - [[2.7542609292], [2.7542609292]]]]], + [[[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]] + ], [[[1.4375670191], [1.4375670191]], [[1.4375670191], + [1.4375670191]]]], + [[[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]] + ], [[[2.7542609292], [2.7542609292]], [[2.7542609292], + [2.7542609292]]]]], dtype=np.float32) expected_state_h = np.array( - [[[[[0.6529865603], [0.6529865603]], - [[0.6529865603], [0.6529865603]]], - [[[0.6529865603], [0.6529865603]], - [[0.6529865603], [0.6529865603]]]], - [[[[0.8736877431], [0.8736877431]], - [[0.8736877431], [0.8736877431]]], - [[[0.8736877431], [0.8736877431]], - [[0.8736877431], [0.8736877431]]]]], + [[[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]] + ], [[[0.6529865603], [0.6529865603]], [[0.6529865603], + [0.6529865603]]]], + [[[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]] + ], [[[0.8736877431], [0.8736877431]], [[0.8736877431], + [0.8736877431]]]]], dtype=np.float32) with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(1.0/8.0)): + "root", initializer=init_ops.constant_initializer(1.0 / 8.0)): x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1]) - cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape, - kernel_shape=filter_size, - output_channels=num_features) + cell = contrib_rnn_cell.Conv3DLSTMCell( + input_shape=shape, + kernel_shape=filter_size, + output_channels=num_features) hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32) output, state = cell(x, hidden) @@ -1056,8 +1040,8 @@ class RNNCellTest(test.TestCase): num_units=num_units, number_of_groups=number_of_groups) cell = rnn_cell.LSTMCell(num_units=num_units) self.assertTrue(isinstance(gcell.state_size, tuple)) - zero_state = gcell.zero_state(batch_size=batch_size, - dtype=dtypes.float32) + zero_state = gcell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) gh, gs = gcell(x, zero_state) h, g = cell(x, zero_state) @@ -1080,16 +1064,16 @@ class RNNCellTest(test.TestCase): glstm_input = array_ops.ones([batch_size, num_units]) gcell = contrib_rnn_cell.GLSTMCell( num_units=num_units, number_of_groups=number_of_groups) - gcell_zero_state = gcell.zero_state(batch_size=batch_size, - dtype=dtypes.float32) + gcell_zero_state = gcell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) gh, gs = gcell(glstm_input, gcell_zero_state) # input for LSTM cell simulating single G-LSTM group lstm_input = array_ops.ones([batch_size, num_units / number_of_groups]) # note division by number_of_groups. This cell one simulates G-LSTM group cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups)) - cell_zero_state = cell.zero_state(batch_size=batch_size, - dtype=dtypes.float32) + cell_zero_state = cell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) h, g = cell(lstm_input, cell_zero_state) sess.run([variables.global_variables_initializer()]) @@ -1099,6 +1083,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(gh_res[:, int(num_units / number_of_groups):], h_res, 1e-5) + class LayerNormBasicLSTMCellTest(test.TestCase): # NOTE: all the values in the current test case have been calculated. @@ -1119,13 +1104,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase): cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)]) g, out_m = cell(x, state) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, out_m], { - x.name: np.array([[1., 1.]]), - c0.name: 0.1 * np.asarray([[0, 1]]), - h0.name: 0.1 * np.asarray([[2, 3]]), - c1.name: 0.1 * np.asarray([[4, 5]]), - h1.name: 0.1 * np.asarray([[6, 7]]), - }) + res = sess.run( + [g, out_m], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + c1.name: 0.1 * np.asarray([[4, 5]]), + h1.name: 0.1 * np.asarray([[6, 7]]), + }) expected_h = np.array([[-0.38079708, 0.38079708]]) expected_state0_c = np.array([[-1.0, 1.0]]) @@ -1155,11 +1141,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase): cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2) g, out_m = cell(x, state) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, out_m], { - x.name: np.array([[1., 1., 1.]]), - c.name: 0.1 * np.asarray([[0, 1]]), - h.name: 0.1 * np.asarray([[2, 3]]), - }) + res = sess.run( + [g, out_m], { + x.name: np.array([[1., 1., 1.]]), + c.name: 0.1 * np.asarray([[0, 1]]), + h.name: 0.1 * np.asarray([[2, 3]]), + }) expected_h = np.array([[-0.38079708, 0.38079708]]) expected_c = np.array([[-1.0, 1.0]]) @@ -1168,7 +1155,6 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertAllClose(res[1].c, expected_c, 1e-5) self.assertAllClose(res[1].h, expected_h, 1e-5) - def testBasicLSTMCellWithoutNorm(self): """Tests that BasicLSTMCell with layer_norm=False.""" with self.test_session() as sess: @@ -1186,19 +1172,20 @@ class LayerNormBasicLSTMCellTest(test.TestCase): cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)]) g, out_m = cell(x, state) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, out_m], { - x.name: np.array([[1., 1.]]), - c0.name: 0.1 * np.asarray([[0, 1]]), - h0.name: 0.1 * np.asarray([[2, 3]]), - c1.name: 0.1 * np.asarray([[4, 5]]), - h1.name: 0.1 * np.asarray([[6, 7]]), - }) + res = sess.run( + [g, out_m], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + c1.name: 0.1 * np.asarray([[4, 5]]), + h1.name: 0.1 * np.asarray([[6, 7]]), + }) - expected_h = np.array([[ 0.70230919, 0.72581059]]) - expected_state0_c = np.array([[ 0.8020075, 0.89599884]]) - expected_state0_h = np.array([[ 0.56668288, 0.60858738]]) - expected_state1_c = np.array([[ 1.17500675, 1.26892781]]) - expected_state1_h = np.array([[ 0.70230919, 0.72581059]]) + expected_h = np.array([[0.70230919, 0.72581059]]) + expected_state0_c = np.array([[0.8020075, 0.89599884]]) + expected_state0_h = np.array([[0.56668288, 0.60858738]]) + expected_state1_c = np.array([[1.17500675, 1.26892781]]) + expected_state1_h = np.array([[0.70230919, 0.72581059]]) actual_h = res[0] actual_state0_c = res[1][0].c @@ -1215,21 +1202,22 @@ class LayerNormBasicLSTMCellTest(test.TestCase): with variable_scope.variable_scope( "other", initializer=init_ops.constant_initializer(0.5)) as vs: x = array_ops.zeros( - [1, 3]) # Test BasicLSTMCell with input_size != num_units. + [1, 3]) # Test BasicLSTMCell with input_size != num_units. c = array_ops.zeros([1, 2]) h = array_ops.zeros([1, 2]) state = rnn_cell.LSTMStateTuple(c, h) cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False) g, out_m = cell(x, state) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, out_m], { - x.name: np.array([[1., 1., 1.]]), - c.name: 0.1 * np.asarray([[0, 1]]), - h.name: 0.1 * np.asarray([[2, 3]]), - }) - - expected_h = np.array([[ 0.64121795, 0.68166804]]) - expected_c = np.array([[ 0.88477188, 0.98103917]]) + res = sess.run( + [g, out_m], { + x.name: np.array([[1., 1., 1.]]), + c.name: 0.1 * np.asarray([[0, 1]]), + h.name: 0.1 * np.asarray([[2, 3]]), + }) + + expected_h = np.array([[0.64121795, 0.68166804]]) + expected_c = np.array([[0.88477188, 0.98103917]]) self.assertEqual(len(res), 2) self.assertAllClose(res[0], expected_h, 1e-5) self.assertAllClose(res[1].c, expected_c, 1e-5) @@ -1250,13 +1238,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase): [contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)]) h, (s0, s1) = cell(x, (state0, state1)) sess.run([variables.global_variables_initializer()]) - res = sess.run([h, s0, s1], { - x.name: np.array([[1., 1.]]), - c0.name: 0.1 * np.asarray([[0, 1]]), - h0.name: 0.1 * np.asarray([[2, 3]]), - c1.name: 0.1 * np.asarray([[4, 5]]), - h1.name: 0.1 * np.asarray([[6, 7]]), - }) + res = sess.run( + [h, s0, s1], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + c1.name: 0.1 * np.asarray([[4, 5]]), + h1.name: 0.1 * np.asarray([[6, 7]]), + }) expected_h = np.array([[-0.38079708, 0.38079708]]) expected_h0 = np.array([[-0.38079708, 0.38079708]]) @@ -1344,11 +1333,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase): g, s = cell(x, state) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { - x.name: np.ones([1, 5]), - c.name: np.ones([1, 5]), - h.name: np.ones([1, 5]), - }) + res = sess.run( + [g, s], { + x.name: np.ones([1, 5]), + c.name: np.ones([1, 5]), + h.name: np.ones([1, 5]), + }) # Since the returned tensors are of size [1,n] # get the first component right now. @@ -1374,35 +1364,35 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertIn(dropped_count, allowed_low) -def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth, - num_layers, max_time, compiled): +def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth, num_layers, + max_time, compiled): with variable_scope.variable_scope( "root", initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)): inputs = variable_scope.get_variable( - "inputs", initializer=random_ops.random_uniform( + "inputs", + initializer=random_ops.random_uniform( (max_time, batch_size, input_depth), seed=1)) maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c cell = rnn_cell.MultiRNNCell( [maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)]) - initial_state = cell.zero_state( - batch_size=batch_size, dtype=dtypes.float32) + initial_state = cell.zero_state(batch_size=batch_size, dtype=dtypes.float32) outputs, final_state = rnn.dynamic_rnn( - cell=cell, inputs=inputs, initial_state=initial_state, - time_major=True) + cell=cell, inputs=inputs, initial_state=initial_state, time_major=True) flat_final_state = nest.flatten(final_state) trainable_variables = variables.trainable_variables() outputs_grad = gradients_impl.gradients( - [outputs], - trainable_variables + [inputs] + nest.flatten(initial_state)) + [outputs], trainable_variables + [inputs] + nest.flatten(initial_state)) final_state_grad = gradients_impl.gradients( flat_final_state, trainable_variables + [inputs] + nest.flatten(initial_state)) - return {"outputs": outputs, - "final_state": flat_final_state, - "outputs_grad": outputs_grad, - "final_state_grad": final_state_grad} + return { + "outputs": outputs, + "final_state": flat_final_state, + "outputs_grad": outputs_grad, + "final_state_grad": final_state_grad + } class CompiledWrapperTest(test.TestCase): @@ -1420,8 +1410,10 @@ class CompiledWrapperTest(test.TestCase): random_seed.set_random_seed(1234) with self.test_session(graph=ops.Graph()) as sess: xla_ops = _create_multi_lstm_cell_ops( - batch_size=batch_size, num_units=num_units, - input_depth=input_depth, num_layers=num_layers, + batch_size=batch_size, + num_units=num_units, + input_depth=input_depth, + num_layers=num_layers, max_time=max_time, compiled=True) sess.run([variables.global_variables_initializer()]) @@ -1430,8 +1422,10 @@ class CompiledWrapperTest(test.TestCase): random_seed.set_random_seed(1234) with self.test_session(graph=ops.Graph()) as sess: non_xla_ops = _create_multi_lstm_cell_ops( - batch_size=batch_size, num_units=num_units, - input_depth=input_depth, num_layers=num_layers, + batch_size=batch_size, + num_units=num_units, + input_depth=input_depth, + num_layers=num_layers, max_time=max_time, compiled=False) sess.run([variables.global_variables_initializer()]) @@ -1440,16 +1434,16 @@ class CompiledWrapperTest(test.TestCase): self.assertAllClose( non_xla_results["outputs"], xla_results["outputs"], atol=atol) - for xla_value, non_xla_value in zip( - xla_results["final_state"], non_xla_results["final_state"]): + for xla_value, non_xla_value in zip(xla_results["final_state"], + non_xla_results["final_state"]): self.assertAllClose(xla_value, non_xla_value, atol=atol) - for xla_g, non_xla_g in zip( - xla_results["outputs_grad"], non_xla_results["outputs_grad"]): + for xla_g, non_xla_g in zip(xla_results["outputs_grad"], + non_xla_results["outputs_grad"]): self.assertAllClose(xla_g, non_xla_g, atol=atol) - for xla_g, non_xla_g in zip( - xla_results["final_state_grad"], non_xla_results["final_state_grad"]): + for xla_g, non_xla_g in zip(xla_results["final_state_grad"], + non_xla_results["final_state_grad"]): self.assertAllClose(xla_g, non_xla_g, atol=atol) def testMultiRNNCellWithStateTuple(self): @@ -1463,19 +1457,20 @@ class CompiledWrapperTest(test.TestCase): # Test incorrectness of state with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): rnn_cell.MultiRNNCell( - [rnn_cell.GRUCell(2) - for _ in range(2)], state_is_tuple=True)(x, m_bad) + [rnn_cell.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_bad) _, ml = rnn_cell.MultiRNNCell( - [rnn_cell.GRUCell(2) - for _ in range(2)], state_is_tuple=True)(x, m_good) + [rnn_cell.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_good) sess.run([variables.global_variables_initializer()]) - res = sess.run(ml, { - x.name: np.array([[1., 1.]]), - m_good[0].name: np.array([[0.1, 0.1]]), - m_good[1].name: np.array([[0.1, 0.1]]) - }) + res = sess.run( + ml, { + x.name: np.array([[1., 1.]]), + m_good[0].name: np.array([[0.1, 0.1]]), + m_good[1].name: np.array([[0.1, 0.1]]) + }) # The numbers in results were not calculated, this is just a # smoke test. However, these numbers should match those of @@ -1490,24 +1485,20 @@ class BenchmarkLSTMCellXLA(test.Benchmark): num_layers = 3 max_time = 50 print("benchmarkDynamicRNNWithMultiLSTMCell") - print("\t" + - "\t".join(["inter_th", "intra_th", - "batch_size", "num_units", "input_depth", "device", - "compiled", "wall_time"])) + print("\t" + "\t".join([ + "inter_th", "intra_th", "batch_size", "num_units", "input_depth", + "device", "compiled", "wall_time" + ])) warmup_run = True - for (threads, - device, - num_units, - batch_size, - input_depth, - compiled) in itertools.product( - [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}], - ["cpu", "gpu"], - [32, 512], - [1, 32, 256], - [32, 512], - [False, True]): + for (threads, device, num_units, batch_size, input_depth, + compiled) in itertools.product([{ + "inter": 0, + "intra": 0 + }, { + "inter": 1, + "intra": 4 + }], ["cpu", "gpu"], [32, 512], [1, 32, 256], [32, 512], [False, True]): if threads["inter"] != 0: # We only care about testing inter/intra op limitations on # CPU with small batch size, to mimic embedded devices. @@ -1523,31 +1514,222 @@ class BenchmarkLSTMCellXLA(test.Benchmark): with session.Session(config=config, graph=ops.Graph()) as sess: with ops.device("/%s:0" % device): ops_dict = _create_multi_lstm_cell_ops( - batch_size=batch_size, num_units=num_units, - input_depth=input_depth, num_layers=num_layers, + batch_size=batch_size, + num_units=num_units, + input_depth=input_depth, + num_layers=num_layers, max_time=max_time, compiled=compiled) sess.run([variables.global_variables_initializer()]) all_ops = nest.flatten(ops_dict.values()) all_ops_group = control_flow_ops.group(*all_ops) - name_suffix = ( - "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d" - "_device_%s_xla_%s" % ( - threads["inter"], threads["intra"], - batch_size, num_units, input_depth, device, compiled)) + name_suffix = ("inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d" + "_device_%s_xla_%s" % + (threads["inter"], threads["intra"], batch_size, + num_units, input_depth, device, compiled)) if warmup_run: self.run_op_benchmark( sess, all_ops_group, min_iters=30, name="ignore_warmup") warmup_run = False benchmark_results = self.run_op_benchmark( - sess, all_ops_group, min_iters=50, + sess, + all_ops_group, + min_iters=50, name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix) - print("\t" + - "\t".join(["%s" % x for x in [ - threads["inter"], threads["intra"], - batch_size, num_units, input_depth, device, compiled, - benchmark_results["wall_time"]]])) + print("\t" + "\t".join([ + "%s" % x + for x in [ + threads["inter"], threads["intra"], batch_size, num_units, + input_depth, device, compiled, benchmark_results["wall_time"] + ] + ])) + + +class WeightNormLSTMCellTest(test.TestCase): + """Compared cell output with pre-calculated values.""" + + def _cell_output(self, cell): + """Calculate cell output""" + + with self.test_session() as sess: + init = init_ops.constant_initializer(0.5) + with variable_scope.variable_scope("root", initializer=init): + x = array_ops.zeros([1, 2]) + c0 = array_ops.zeros([1, 2]) + h0 = array_ops.zeros([1, 2]) + + state0 = rnn_cell.LSTMStateTuple(c0, h0) + + xout, sout = cell()(x, state0) + + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [xout, sout], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + }) + + actual_state_c = res[1].c + actual_state_h = res[1].h + + return actual_state_c, actual_state_h + + def testBasicCell(self): + """Tests cell w/o peepholes and w/o normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell( + 2, norm=False, use_peepholes=False) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.65937078, 0.74983585]]) + expected_h = np.array([[0.44923624, 0.49362513]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + def testNonbasicCell(self): + """Tests cell with peepholes and w/o normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell( + 2, norm=False, use_peepholes=True) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.65937084, 0.7574988]]) + expected_h = np.array([[0.4792085, 0.53470564]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + def testBasicCellWithNorm(self): + """Tests cell w/o peepholes and with normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell( + 2, norm=True, use_peepholes=False) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.50125383, 0.58805949]]) + expected_h = np.array([[0.32770363, 0.37397948]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + def testNonBasicCellWithNorm(self): + """Tests cell with peepholes and with normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell( + 2, norm=True, use_peepholes=True) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.50125383, 0.59587258]]) + expected_h = np.array([[0.35041603, 0.40873795]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + +class WeightNormLSTMCellTest(test.TestCase): + """Compared cell output with pre-calculated values.""" + + def _cell_output(self, cell): + """Calculate cell output""" + + with self.test_session() as sess: + init = init_ops.constant_initializer(0.5) + with variable_scope.variable_scope("root", + initializer=init): + x = array_ops.zeros([1, 2]) + c0 = array_ops.zeros([1, 2]) + h0 = array_ops.zeros([1, 2]) + + state0 = rnn_cell.LSTMStateTuple(c0, h0) + + xout, sout = cell()(x, state0) + + sess.run([variables.global_variables_initializer()]) + res = sess.run([xout, sout], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + }) + + actual_state_c = res[1].c + actual_state_h = res[1].h + + return actual_state_c, actual_state_h + + def testBasicCell(self): + """Tests cell w/o peepholes and w/o normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=False) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.65937078, 0.74983585]]) + expected_h = np.array([[0.44923624, 0.49362513]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + def testNonbasicCell(self): + """Tests cell with peepholes and w/o normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=True) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.65937084, 0.7574988]]) + expected_h = np.array([[0.4792085, 0.53470564]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + + def testBasicCellWithNorm(self): + """Tests cell w/o peepholes and with normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=False) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.50125383, 0.58805949]]) + expected_h = np.array([[0.32770363, 0.37397948]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + def testNonBasicCellWithNorm(self): + """Tests cell with peepholes and with normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=True) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.50125383, 0.59587258]]) + expected_h = np.array([[0.35041603, 0.40873795]]) + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index e4667828cdaad627143efcb823eee39aec24fab7..8adf5dce6ec76d8ac4f182929e0dfc81be946277 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Module for constructing RNN Cells.""" from __future__ import absolute_import from __future__ import division @@ -38,6 +37,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import nn_impl from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -55,16 +55,15 @@ def _get_concat_variable(name, shape, dtype, num_shards): return value concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name) - ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, - concat_variable) + ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable) return concat_variable def _get_sharded_variable(name, shape, dtype, num_shards): """Get a list of sharded variables with the given dtype.""" if num_shards > shape[0]: - raise ValueError("Too many shards: shape=%s, num_shards=%d" % - (shape, num_shards)) + raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape, + num_shards)) unit_shard_size = int(math.floor(shape[0] / num_shards)) remaining_rows = shape[0] - unit_shard_size * num_shards @@ -73,8 +72,9 @@ def _get_sharded_variable(name, shape, dtype, num_shards): current_size = unit_shard_size if i < remaining_rows: current_size += 1 - shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:], - dtype=dtype)) + shards.append( + vs.get_variable( + name + "_%d" % i, [current_size] + shape[1:], dtype=dtype)) return shards @@ -176,9 +176,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): """ super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: - logging.warn( - "%s: Using a concatenated state is slower and will soon be " - "deprecated. Use state_is_tuple=True.", self) + logging.warn("%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) self._num_units = num_units self._use_peepholes = use_peepholes self._initializer = initializer @@ -195,12 +194,14 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): self._norm_shift = norm_shift if num_proj: - self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj) - if state_is_tuple else num_units + num_proj) + self._state_size = ( + rnn_cell_impl.LSTMStateTuple(num_units, num_proj) + if state_is_tuple else num_units + num_proj) self._output_size = num_proj else: - self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units) - if state_is_tuple else 2 * num_units) + self._state_size = ( + rnn_cell_impl.LSTMStateTuple(num_units, num_units) + if state_is_tuple else 2 * num_units) self._output_size = num_units @property @@ -250,8 +251,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") concat_w = _get_concat_variable( - "W", [input_size.value + num_proj, 3 * self._num_units], - dtype, self._num_unit_shards) + "W", [input_size.value + num_proj, 3 * self._num_units], dtype, + self._num_unit_shards) b = vs.get_variable( "B", @@ -298,9 +299,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): m = sigmoid(o) * self._activation(c) if self._num_proj is not None: - concat_w_proj = _get_concat_variable( - "W_P", [self._num_units, self._num_proj], - dtype, self._num_proj_shards) + concat_w_proj = _get_concat_variable("W_P", + [self._num_units, self._num_proj], + dtype, self._num_proj_shards) m = math_ops.matmul(m, concat_w_proj) if self._proj_clip is not None: @@ -308,8 +309,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) # pylint: enable=invalid-unary-operand-type - new_state = (rnn_cell_impl.LSTMStateTuple(c, m) - if self._state_is_tuple else array_ops.concat([c, m], 1)) + new_state = ( + rnn_cell_impl.LSTMStateTuple(c, m) + if self._state_is_tuple else array_ops.concat([c, m], 1)) return m, new_state @@ -325,10 +327,15 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): It uses peep-hole connections and optional cell clipping. """ - def __init__(self, num_units, use_peepholes=False, - cell_clip=None, initializer=None, - num_unit_shards=1, forget_bias=1.0, - feature_size=None, frequency_skip=None, + def __init__(self, + num_units, + use_peepholes=False, + cell_clip=None, + initializer=None, + num_unit_shards=1, + forget_bias=1.0, + feature_size=None, + frequency_skip=1, reuse=None): """Initialize the parameters for an LSTM cell. @@ -398,7 +405,7 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): actual_input_size = freq_inputs[0].get_shape().as_list()[1] concat_w = _get_concat_variable( - "W", [actual_input_size + 2*self._num_units, 4 * self._num_units], + "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units], dtype, self._num_unit_shards) b = vs.get_variable( @@ -417,23 +424,23 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): "W_O_diag", shape=[self._num_units], dtype=dtype) # initialize the first freq state to be zero - m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), - self._num_units], dtype) + m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units], + dtype) for fq in range(len(freq_inputs)): - c_prev = array_ops.slice(state, [0, 2*fq*self._num_units], + c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], [-1, self._num_units]) - m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units], + m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units], [-1, self._num_units]) # i = input_gate, j = new_input, f = forget_gate, o = output_gate - cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], - 1) + cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1) lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) i, j, f, o = array_ops.split( value=lstm_matrix, num_or_size_splits=4, axis=1) if self._use_peepholes: - c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + - sigmoid(i + w_i_diag * c_prev) * tanh(j)) + c = ( + sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + + sigmoid(i + w_i_diag * c_prev) * tanh(j)) else: c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) @@ -471,11 +478,11 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): input_size = input_feat.get_shape().with_rank(2)[-1].value if input_size is None: raise ValueError("Cannot infer input_size from static shape inference.") - num_feats = int((input_size - self._feature_size) / ( - self._frequency_skip)) + 1 + num_feats = int( + (input_size - self._feature_size) / (self._frequency_skip)) + 1 freq_inputs = [] for f in range(num_feats): - cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip], + cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip], [-1, self._feature_size]) freq_inputs.append(cur_input) return freq_inputs @@ -497,11 +504,16 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): The code uses optional peephole connections, shared_weights and cell clipping. """ - def __init__(self, num_units, use_peepholes=False, + def __init__(self, + num_units, + use_peepholes=False, share_time_frequency_weights=False, - cell_clip=None, initializer=None, - num_unit_shards=1, forget_bias=1.0, - feature_size=None, frequency_skip=None, + cell_clip=None, + initializer=None, + num_unit_shards=1, + forget_bias=1.0, + feature_size=None, + frequency_skip=None, num_frequency_blocks=None, start_freqindex_list=None, end_freqindex_list=None, @@ -579,10 +591,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): for freq_index in range(self._num_frequency_blocks[block_index]): name_prefix = "state_f%02d_b%02d" % (freq_index, block_index) state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) - self._state_tuple_type = collections.namedtuple( - "GridLSTMStateTuple", state_names.strip(",")) - self._state_size = self._state_tuple_type( - *([num_units, num_units] * self._total_blocks)) + self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple", + state_names.strip(",")) + self._state_size = self._state_tuple_type(*( + [num_units, num_units] * self._total_blocks)) else: self._state_tuple_type = None self._state_size = num_units * self._total_blocks * 2 @@ -625,7 +637,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): state_out_lst = [] for block in range(len(freq_inputs)): m_out_lst_current, state_out_lst_current = self._compute( - freq_inputs[block], block, state, batch_size, + freq_inputs[block], + block, + state, + batch_size, state_is_tuple=self._state_is_tuple) m_out_lst.extend(m_out_lst_current) state_out_lst.extend(state_out_lst_current) @@ -636,7 +651,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): m_out = array_ops.concat(m_out_lst, 1) return m_out, state_out - def _compute(self, freq_inputs, block, state, batch_size, + def _compute(self, + freq_inputs, + block, + state, + batch_size, state_prefix="state", state_is_tuple=True): """Run the actual computation of one step LSTM. @@ -665,8 +684,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): actual_input_size = freq_inputs[0].get_shape().as_list()[1] concat_w_f = _get_concat_variable( - "W_f_%d" % block, [actual_input_size + 2 * self._num_units, - num_gates * self._num_units], + "W_f_%d" % block, + [actual_input_size + 2 * self._num_units, num_gates * self._num_units], dtype, self._num_unit_shards) b_f = vs.get_variable( "B_f_%d" % block, @@ -674,10 +693,9 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): initializer=init_ops.zeros_initializer(), dtype=dtype) if not self._share_time_frequency_weights: - concat_w_t = _get_concat_variable( - "W_t_%d" % block, [actual_input_size + 2 * self._num_units, - num_gates * self._num_units], - dtype, self._num_unit_shards) + concat_w_t = _get_concat_variable("W_t_%d" % block, [ + actual_input_size + 2 * self._num_units, num_gates * self._num_units + ], dtype, self._num_unit_shards) b_t = vs.get_variable( "B_t_%d" % block, shape=[num_gates * self._num_units], @@ -690,7 +708,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): w_f_diag_freqf = vs.get_variable( "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) w_f_diag_freqt = vs.get_variable( - "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype) + "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) w_i_diag_freqf = vs.get_variable( "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) w_i_diag_freqt = vs.get_variable( @@ -724,8 +742,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): m_prev_time = getattr(state, name_prefix + "_m") else: c_prev_time = array_ops.slice( - state, [0, 2 * freq_index * self._num_units], - [-1, self._num_units]) + state, [0, 2 * freq_index * self._num_units], [-1, self._num_units]) m_prev_time = array_ops.slice( state, [0, (2 * freq_index + 1) * self._num_units], [-1, self._num_units]) @@ -735,8 +752,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1) # F-LSTM - lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs, - concat_w_f), b_f) + lstm_matrix_freq = nn_ops.bias_add( + math_ops.matmul(cell_inputs, concat_w_f), b_f) if self._couple_input_forget_gates: i_freq, j_freq, o_freq = array_ops.split( value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) @@ -751,8 +768,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): f_time = f_freq o_time = o_freq else: - lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs, - concat_w_t), b_t) + lstm_matrix_time = nn_ops.bias_add( + math_ops.matmul(cell_inputs, concat_w_t), b_t) if self._couple_input_forget_gates: i_time, j_time, o_time = array_ops.split( value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) @@ -764,8 +781,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): # F-LSTM c_freq # input gate activations if self._use_peepholes: - i_freq_g = sigmoid(i_freq + - w_i_diag_freqf * c_prev_freq + + i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq + w_i_diag_freqt * c_prev_time) else: i_freq_g = sigmoid(i_freq) @@ -774,9 +790,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): f_freq_g = 1.0 - i_freq_g else: if self._use_peepholes: - f_freq_g = sigmoid(f_freq + self._forget_bias + - w_f_diag_freqf * c_prev_freq + - w_f_diag_freqt * c_prev_time) + f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf * + c_prev_freq + w_f_diag_freqt * c_prev_time) else: f_freq_g = sigmoid(f_freq + self._forget_bias) # cell state @@ -791,12 +806,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): # input gate activations if self._use_peepholes: if self._share_time_frequency_weights: - i_time_g = sigmoid(i_time + - w_i_diag_freqf * c_prev_freq + + i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq + w_i_diag_freqt * c_prev_time) else: - i_time_g = sigmoid(i_time + - w_i_diag_timef * c_prev_freq + + i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq + w_i_diag_timet * c_prev_time) else: i_time_g = sigmoid(i_time) @@ -806,13 +819,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): else: if self._use_peepholes: if self._share_time_frequency_weights: - f_time_g = sigmoid(f_time + self._forget_bias + - w_f_diag_freqf * c_prev_freq + - w_f_diag_freqt * c_prev_time) + f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf * + c_prev_freq + w_f_diag_freqt * c_prev_time) else: - f_time_g = sigmoid(f_time + self._forget_bias + - w_f_diag_timef * c_prev_freq + - w_f_diag_timet * c_prev_time) + f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef * + c_prev_freq + w_f_diag_timet * c_prev_time) else: f_time_g = sigmoid(f_time + self._forget_bias) # cell state @@ -825,8 +836,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): # F-LSTM m_freq if self._use_peepholes: - m_freq = sigmoid(o_freq + - w_o_diag_freqf * c_freq + + m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq + w_o_diag_freqt * c_time) * tanh(c_freq) else: m_freq = sigmoid(o_freq) * tanh(c_freq) @@ -834,12 +844,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): # T-LSTM m_time if self._use_peepholes: if self._share_time_frequency_weights: - m_time = sigmoid(o_time + - w_o_diag_freqf * c_freq + + m_time = sigmoid(o_time + w_o_diag_freqf * c_freq + w_o_diag_freqt * c_time) * tanh(c_time) else: - m_time = sigmoid(o_time + - w_o_diag_timef * c_freq + + m_time = sigmoid(o_time + w_o_diag_timef * c_freq + w_o_diag_timet * c_time) * tanh(c_time) else: m_time = sigmoid(o_time) * tanh(c_time) @@ -878,16 +886,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): raise ValueError("Cannot infer input_size from static shape inference.") if slice_offset > 0: # Padding to the end - inputs = array_ops.pad( - input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2], - dtype=dtypes.int32), - "CONSTANT") + inputs = array_ops.pad(input_feat, + array_ops.constant( + [0, 0, 0, slice_offset], + shape=[2, 2], + dtype=dtypes.int32), "CONSTANT") elif slice_offset < 0: # Padding to the front - inputs = array_ops.pad( - input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2], - dtype=dtypes.int32), - "CONSTANT") + inputs = array_ops.pad(input_feat, + array_ops.constant( + [0, 0, -slice_offset, 0], + shape=[2, 2], + dtype=dtypes.int32), "CONSTANT") slice_offset = 0 else: inputs = input_feat @@ -897,13 +907,13 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): raise ValueError("Length of num_frequency_blocks" " is not 1, but instead is %d", len(self._num_frequency_blocks)) - num_feats = int((input_size - self._feature_size) / ( - self._frequency_skip)) + 1 + num_feats = int( + (input_size - self._feature_size) / (self._frequency_skip)) + 1 if num_feats != self._num_frequency_blocks[0]: raise ValueError( "Invalid num_frequency_blocks, requires %d but gets %d, please" - " check the input size and filter config are correct." % ( - self._num_frequency_blocks[0], num_feats)) + " check the input size and filter config are correct." % + (self._num_frequency_blocks[0], num_feats)) block_inputs = [] for f in range(num_feats): cur_input = array_ops.slice( @@ -926,18 +936,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): start_index = self._start_freqindex_list[b] end_index = self._end_freqindex_list[b] cur_size = end_index - start_index - block_feats = int((cur_size - self._feature_size) / ( - self._frequency_skip)) + 1 + block_feats = int( + (cur_size - self._feature_size) / (self._frequency_skip)) + 1 if block_feats != self._num_frequency_blocks[b]: raise ValueError( "Invalid num_frequency_blocks, requires %d but gets %d, please" - " check the input size and filter config are correct." % ( - self._num_frequency_blocks[b], block_feats)) + " check the input size and filter config are correct." % + (self._num_frequency_blocks[b], block_feats)) block_inputs = [] for f in range(block_feats): cur_input = array_ops.slice( - inputs, [0, start_index + slice_offset + f * - self._frequency_skip], + inputs, + [0, start_index + slice_offset + f * self._frequency_skip], [-1, self._feature_size]) block_inputs.append(cur_input) freq_inputs.append(block_inputs) @@ -953,11 +963,16 @@ class BidirectionalGridLSTMCell(GridLSTMCell): The current implementation uses different weights for the two directions. """ - def __init__(self, num_units, use_peepholes=False, + def __init__(self, + num_units, + use_peepholes=False, share_time_frequency_weights=False, - cell_clip=None, initializer=None, - num_unit_shards=1, forget_bias=1.0, - feature_size=None, frequency_skip=None, + cell_clip=None, + initializer=None, + num_unit_shards=1, + forget_bias=1.0, + feature_size=None, + frequency_skip=None, num_frequency_blocks=None, start_freqindex_list=None, end_freqindex_list=None, @@ -1016,8 +1031,8 @@ class BidirectionalGridLSTMCell(GridLSTMCell): state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) self._state_tuple_type = collections.namedtuple( "BidirectionalGridLSTMStateTuple", state_names.strip(",")) - self._state_size = self._state_tuple_type( - *([num_units, num_units] * self._total_blocks * 2)) + self._state_size = self._state_tuple_type(*( + [num_units, num_units] * self._total_blocks * 2)) self._output_size = 2 * num_units * self._total_blocks * 2 def call(self, inputs, state): @@ -1051,8 +1066,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell): fwd_state_out_lst = [] for block in range(len(fwd_inputs)): fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( - fwd_inputs[block], block, state, batch_size, - state_prefix="fwd_state", state_is_tuple=True) + fwd_inputs[block], + block, + state, + batch_size, + state_prefix="fwd_state", + state_is_tuple=True) fwd_m_out_lst.extend(fwd_m_out_lst_current) fwd_state_out_lst.extend(fwd_state_out_lst_current) # Backward processing @@ -1063,8 +1082,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell): # Reverse the blocks bwd_inputs_reverse = bwd_inputs[block][::-1] bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( - bwd_inputs_reverse, block, state, batch_size, - state_prefix="bwd_state", state_is_tuple=True) + bwd_inputs_reverse, + block, + state, + batch_size, + state_prefix="bwd_state", + state_is_tuple=True) bwd_m_out_lst.extend(bwd_m_out_lst_current) bwd_state_out_lst.extend(bwd_state_out_lst_current) state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst)) @@ -1075,6 +1098,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell): # pylint: disable=protected-access _Linear = core_rnn_cell._Linear # pylint: disable=invalid-name + # pylint: enable=protected-access @@ -1084,8 +1108,14 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): Implementation based on https://arxiv.org/abs/1409.0473. """ - def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None, - input_size=None, state_is_tuple=True, reuse=None): + def __init__(self, + cell, + attn_length, + attn_size=None, + attn_vec_size=None, + input_size=None, + state_is_tuple=True, + reuse=None): """Create a cell with attention. Args: @@ -1115,16 +1145,15 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("The parameter cell is not RNNCell.") if nest.is_sequence(cell.state_size) and not state_is_tuple: - raise ValueError("Cell returns tuple of states, but the flag " - "state_is_tuple is not set. State size is: %s" - % str(cell.state_size)) + raise ValueError( + "Cell returns tuple of states, but the flag " + "state_is_tuple is not set. State size is: %s" % str(cell.state_size)) if attn_length <= 0: - raise ValueError("attn_length should be greater than zero, got %s" - % str(attn_length)) + raise ValueError( + "attn_length should be greater than zero, got %s" % str(attn_length)) if not state_is_tuple: - logging.warn( - "%s: Using a concatenated state is slower and will soon be " - "deprecated. Use state_is_tuple=True.", self) + logging.warn("%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) if attn_size is None: attn_size = cell.output_size if attn_vec_size is None: @@ -1160,8 +1189,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): else: states = state state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size]) - attns = array_ops.slice( - states, [0, self._cell.state_size], [-1, self._attn_size]) + attns = array_ops.slice(states, [0, self._cell.state_size], + [-1, self._attn_size]) attn_states = array_ops.slice( states, [0, self._cell.state_size + self._attn_size], [-1, self._attn_size * self._attn_length]) @@ -1199,8 +1228,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): tanh = math_ops.tanh with vs.variable_scope("attention"): - k = vs.get_variable( - "attn_w", [1, 1, self._attn_size, self._attn_vec_size]) + k = vs.get_variable("attn_w", + [1, 1, self._attn_size, self._attn_vec_size]) v = vs.get_variable("attn_v", [self._attn_vec_size]) hidden = array_ops.reshape(attn_states, [-1, self._attn_length, 1, self._attn_size]) @@ -1227,7 +1256,8 @@ class HighwayWrapper(rnn_cell_impl.RNNCell): https://arxiv.org/abs/1505.00387 """ - def __init__(self, cell, + def __init__(self, + cell, couple_carry_transform_gates=True, carry_bias_init=1.0): """Constructs a `HighwayWrapper` for `cell`. @@ -1259,8 +1289,7 @@ class HighwayWrapper(rnn_cell_impl.RNNCell): carry_weight = vs.get_variable("carry_w", [input_size, input_size]) carry_bias = vs.get_variable( "carry_b", [input_size], - initializer=init_ops.constant_initializer( - self._carry_bias_init)) + initializer=init_ops.constant_initializer(self._carry_bias_init)) carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias)) if self._couple_carry_transform_gates: transform = 1 - carry @@ -1269,11 +1298,9 @@ class HighwayWrapper(rnn_cell_impl.RNNCell): [input_size, input_size]) transform_bias = vs.get_variable( "transform_b", [input_size], - initializer=init_ops.constant_initializer( - -self._carry_bias_init)) - transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp, - transform_weight, - transform_bias)) + initializer=init_ops.constant_initializer(-self._carry_bias_init)) + transform = math_ops.sigmoid( + nn_ops.xw_plus_b(inp, transform_weight, transform_bias)) return inp * carry + out * transform def __call__(self, inputs, state, scope=None): @@ -1293,9 +1320,11 @@ class HighwayWrapper(rnn_cell_impl.RNNCell): """ outputs, new_state = self._cell(inputs, state, scope=scope) nest.assert_same_structure(inputs, outputs) + # Ensure shapes match def assert_shape_match(inp, out): inp.get_shape().assert_is_compatible_with(out.get_shape()) + nest.map_structure(assert_shape_match, inputs, outputs) res_outputs = nest.map_structure(self._highway, inputs, outputs) return (res_outputs, new_state) @@ -1321,10 +1350,16 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth. """ - def __init__(self, num_units, forget_bias=1.0, - input_size=None, activation=math_ops.tanh, - layer_norm=True, norm_gain=1.0, norm_shift=0.0, - dropout_keep_prob=1.0, dropout_prob_seed=None, + def __init__(self, + num_units, + forget_bias=1.0, + input_size=None, + activation=math_ops.tanh, + layer_norm=True, + norm_gain=1.0, + norm_shift=0.0, + dropout_keep_prob=1.0, + dropout_prob_seed=None, reuse=None): """Initializes the basic LSTM cell. @@ -1409,8 +1444,8 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) - new_c = (c * math_ops.sigmoid(f + self._forget_bias) - + math_ops.sigmoid(i) * g) + new_c = ( + c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g) if self._layer_norm: new_c = self._norm(new_c, "state", dtype=dtype) new_h = self._activation(new_c) * math_ops.sigmoid(o) @@ -1432,8 +1467,7 @@ class NASCell(rnn_cell_impl.RNNCell): The class uses an optional projection layer. """ - def __init__(self, num_units, num_proj=None, - use_biases=False, reuse=None): + def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None): """Initialize the parameters for a NAS cell. Args: @@ -1503,12 +1537,10 @@ class NASCell(rnn_cell_impl.RNNCell): raise ValueError("Could not infer input size from inputs.get_shape()[-1]") # Variables for the NAS cell. W_m is all matrices multiplying the # hiddenstate and W_inputs is all matrices multiplying the inputs. - concat_w_m = vs.get_variable( - "recurrent_kernel", [num_proj, 8 * self._num_units], - dtype) + concat_w_m = vs.get_variable("recurrent_kernel", + [num_proj, 8 * self._num_units], dtype) concat_w_inputs = vs.get_variable( - "kernel", [input_size.value, 8 * self._num_units], - dtype) + "kernel", [input_size.value, 8 * self._num_units], dtype) m_matrix = math_ops.matmul(m_prev, concat_w_m) inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) @@ -1523,10 +1555,10 @@ class NASCell(rnn_cell_impl.RNNCell): # The NAS cell branches into 8 different splits for both the hiddenstate # and the input - m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8, - value=m_matrix) - inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8, - value=inputs_matrix) + m_matrix_splits = array_ops.split( + axis=1, num_or_size_splits=8, value=m_matrix) + inputs_matrix_splits = array_ops.split( + axis=1, num_or_size_splits=8, value=inputs_matrix) # First layer layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) @@ -1558,9 +1590,8 @@ class NASCell(rnn_cell_impl.RNNCell): # Projection layer if specified if self._num_proj is not None: - concat_w_proj = vs.get_variable( - "projection_weights", [self._num_units, self._num_proj], - dtype) + concat_w_proj = vs.get_variable("projection_weights", + [self._num_units, self._num_proj], dtype) new_m = math_ops.matmul(new_m, concat_w_proj) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m) @@ -1583,8 +1614,12 @@ class UGRNNCell(rnn_cell_impl.RNNCell): "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017. """ - def __init__(self, num_units, initializer=None, forget_bias=1.0, - activation=math_ops.tanh, reuse=None): + def __init__(self, + num_units, + initializer=None, + forget_bias=1.0, + activation=math_ops.tanh, + reuse=None): """Initialize the parameters for an UGRNN cell. Args: @@ -1639,8 +1674,8 @@ class UGRNNCell(rnn_cell_impl.RNNCell): if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with vs.variable_scope(vs.get_variable_scope(), - initializer=self._initializer): + with vs.variable_scope( + vs.get_variable_scope(), initializer=self._initializer): cell_inputs = array_ops.concat([inputs, state], 1) if self._linear is None: self._linear = _Linear(cell_inputs, 2 * self._num_units, True) @@ -1680,9 +1715,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): RNNs so it may not achieve best performance with depth 1. """ - def __init__(self, num_units, num_in_proj=None, - initializer=None, forget_bias=1.0, - y_activation=nn_ops.relu, reuse=None): + def __init__(self, + num_units, + num_in_proj=None, + initializer=None, + forget_bias=1.0, + y_activation=nn_ops.relu, + reuse=None): """Initialize the parameters for an +RNN cell. Args: @@ -1746,8 +1785,8 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with vs.variable_scope(vs.get_variable_scope(), - initializer=self._initializer): + with vs.variable_scope( + vs.get_variable_scope(), initializer=self._initializer): # read-in projections (should be used for first layer in deep +RNN # to transform size of inputs from I --> N) if input_size.value != self._num_units: @@ -1764,13 +1803,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): n_dim = i_dim = self._num_units cell_inputs = array_ops.concat([inputs, state], 1) if self._linear2 is None: - self._linear2 = _Linear(cell_inputs, 2*n_dim + 2*i_dim, True) + self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True) rnn_matrix = self._linear2(cell_inputs) - gh_act = rnn_matrix[:, :n_dim] # b x n - h_act = rnn_matrix[:, n_dim:2*n_dim] # b x n - gy_act = rnn_matrix[:, 2*n_dim:2*n_dim+i_dim] # b x i - y_act = rnn_matrix[:, 2*n_dim+i_dim:2*n_dim+2*i_dim] # b x i + gh_act = rnn_matrix[:, :n_dim] # b x n + h_act = rnn_matrix[:, n_dim:2 * n_dim] # b x n + gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim] # b x i + y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim] # b x i h = tanh(h_act) y = self._y_activation(y_act) @@ -1816,6 +1855,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell): if self._compile_stateful: compile_ops = True else: + def compile_ops(node_def): global _REGISTERED_OPS if _REGISTERED_OPS is None: @@ -1826,10 +1866,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell): return self._cell(inputs, state, scope=scope) -def _random_exp_initializer(minval, - maxval, - seed=None, - dtype=dtypes.float32): +def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32): """Returns an exponential distribution initializer. Args: @@ -1848,10 +1885,7 @@ def _random_exp_initializer(minval, del partition_info # Unused. return math_ops.exp( random_ops.random_uniform( - shape, - math_ops.log(minval), - math_ops.log(maxval), - dtype, + shape, math_ops.log(minval), math_ops.log(maxval), dtype, seed=seed)) return _initializer @@ -1955,8 +1989,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): if self._linear1 is None: self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True) - mask_gates = math_ops.sigmoid( - self._linear1(in_mask_gates)) + mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates)) [input_gate, forget_gate] = array_ops.split( axis=1, num_or_size_splits=2, value=mask_gates) @@ -1980,12 +2013,12 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): period = vs.get_variable( "period", [self._num_units], - initializer=_random_exp_initializer( - self._period_init_min, self._period_init_max)) + initializer=_random_exp_initializer(self._period_init_min, + self._period_init_max)) phase = vs.get_variable( "phase", [self._num_units], - initializer=init_ops.random_uniform_initializer( - 0., period.initial_value)) + initializer=init_ops.random_uniform_initializer(0., + period.initial_value)) ratio_on = vs.get_variable( "ratio_on", [self._num_units], initializer=init_ops.constant_initializer(self._ratio_on), @@ -2007,6 +2040,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): return new_h, new_state + class ConvLSTMCell(rnn_cell_impl.RNNCell): """Convolutional LSTM recurrent network cell. @@ -2040,7 +2074,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): """ super(ConvLSTMCell, self).__init__(name=name) - if conv_ndims != len(input_shape)-1: + if conv_ndims != len(input_shape) - 1: raise ValueError("Invalid input_shape {} for conv_ndims={}.".format( input_shape, conv_ndims)) @@ -2059,8 +2093,8 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): state_size = tensor_shape.TensorShape( self._input_shape[:-1] + [self._output_channels]) self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size) - self._output_size = tensor_shape.TensorShape(self._input_shape[:-1] - + [self._total_output_channels]) + self._output_size = tensor_shape.TensorShape( + self._input_shape[:-1] + [self._total_output_channels]) @property def output_size(self): @@ -2072,13 +2106,10 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): def call(self, inputs, state, scope=None): cell, hidden = state - new_hidden = _conv([inputs, hidden], - self._kernel_shape, - 4*self._output_channels, - self._use_bias) - gates = array_ops.split(value=new_hidden, - num_or_size_splits=4, - axis=self._conv_ndims+1) + new_hidden = _conv([inputs, hidden], self._kernel_shape, + 4 * self._output_channels, self._use_bias) + gates = array_ops.split( + value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1) input_gate, new_input, forget_gate, output_gate = gates new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell @@ -2090,29 +2121,35 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output) return output, new_state + class Conv1DLSTMCell(ConvLSTMCell): """1D Convolutional LSTM recurrent network cell. https://arxiv.org/pdf/1506.04214v1.pdf """ + def __init__(self, name="conv_1d_lstm_cell", **kwargs): """Construct Conv1DLSTM. See `ConvLSTMCell` for more details.""" super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs) + class Conv2DLSTMCell(ConvLSTMCell): """2D Convolutional LSTM recurrent network cell. https://arxiv.org/pdf/1506.04214v1.pdf """ + def __init__(self, name="conv_2d_lstm_cell", **kwargs): """Construct Conv2DLSTM. See `ConvLSTMCell` for more details.""" super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs) + class Conv3DLSTMCell(ConvLSTMCell): """3D Convolutional LSTM recurrent network cell. https://arxiv.org/pdf/1506.04214v1.pdf """ + def __init__(self, name="conv_3d_lstm_cell", **kwargs): """Construct Conv3DLSTM. See `ConvLSTMCell` for more details.""" super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs) @@ -2137,7 +2174,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0): shapes = [a.get_shape().as_list() for a in args] shape_length = len(shapes[0]) for shape in shapes: - if len(shape) not in [3,4,5]: + if len(shape) not in [3, 4, 5]: raise ValueError("Conv Linear expects 3D, 4D " "or 5D arguments: %s" % str(shapes)) if len(shape) != len(shapes[0]): @@ -2148,40 +2185,36 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0): dtype = [a.dtype for a in args][0] # determine correct conv operation - if shape_length == 3: + if shape_length == 3: conv_op = nn_ops.conv1d strides = 1 elif shape_length == 4: conv_op = nn_ops.conv2d - strides = shape_length*[1] + strides = shape_length * [1] elif shape_length == 5: conv_op = nn_ops.conv3d - strides = shape_length*[1] + strides = shape_length * [1] # Now the computation. kernel = vs.get_variable( - "kernel", - filter_size + [total_arg_size_depth, num_features], - dtype=dtype) + "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype) if len(args) == 1: - res = conv_op(args[0], - kernel, - strides, - padding='SAME') + res = conv_op(args[0], kernel, strides, padding="SAME") else: - res = conv_op(array_ops.concat(axis=shape_length-1, values=args), - kernel, - strides, - padding='SAME') + res = conv_op( + array_ops.concat(axis=shape_length - 1, values=args), + kernel, + strides, + padding="SAME") if not bias: return res bias_term = vs.get_variable( "biases", [num_features], dtype=dtype, - initializer=init_ops.constant_initializer( - bias_start, dtype=dtype)) + initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) return res + bias_term + class GLSTMCell(rnn_cell_impl.RNNCell): """Group LSTM cell (G-LSTM). @@ -2193,8 +2226,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell): "Factorization Tricks for LSTM Networks", ICLR 2017 workshop. """ - def __init__(self, num_units, initializer=None, num_proj=None, - number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh, + def __init__(self, + num_units, + initializer=None, + num_proj=None, + number_of_groups=1, + forget_bias=1.0, + activation=math_ops.tanh, reuse=None): """Initialize the parameters of G-LSTM cell. @@ -2231,11 +2269,15 @@ class GLSTMCell(rnn_cell_impl.RNNCell): if self._num_proj: if self._num_proj % self._number_of_groups != 0: raise ValueError("num_proj must be divisible by number_of_groups") - self._group_shape = [int(self._num_proj / self._number_of_groups), - int(self._num_units / self._number_of_groups)] + self._group_shape = [ + int(self._num_proj / self._number_of_groups), + int(self._num_units / self._number_of_groups) + ] else: - self._group_shape = [int(self._num_units / self._number_of_groups), - int(self._num_units / self._number_of_groups)] + self._group_shape = [ + int(self._num_units / self._number_of_groups), + int(self._num_units / self._number_of_groups) + ] if num_proj: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) @@ -2267,10 +2309,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell): subset of inputs corresponding to group "group_id", a Tensor, 2D, [batch x num_units/number_of_groups] """ - return array_ops.slice(input_=inputs, - begin=[0, group_id * group_size], - size=[self._batch_size, group_size], - name=("GLSTM_group%d_input_generation" % group_id)) + return array_ops.slice( + input_=inputs, + begin=[0, group_id * group_size], + size=[self._batch_size, group_size], + name=("GLSTM_group%d_input_generation" % group_id)) def call(self, inputs, state): """Run one step of G-LSTM. @@ -2309,10 +2352,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell): for group_id in range(self._number_of_groups): with vs.variable_scope("group%d" % group_id): x_g_id = array_ops.concat( - [self._get_input_for_group(inputs, group_id, - self._group_shape[0]), - self._get_input_for_group(m_prev, group_id, - self._group_shape[0])], axis=1) + [ + self._get_input_for_group(inputs, group_id, + self._group_shape[0]), + self._get_input_for_group(m_prev, group_id, + self._group_shape[0]) + ], + axis=1) if self._linear1 is None: self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False) R_k = self._linear1(x_g_id) # pylint: disable=invalid-name @@ -2323,34 +2369,35 @@ class GLSTMCell(rnn_cell_impl.RNNCell): f_parts.append(f_k) o_parts.append(o_k) - bi = vs.get_variable(name="bias_i", - shape=[self._num_units], - dtype=dtype, - initializer= - init_ops.constant_initializer(0.0, dtype=dtype)) - bj = vs.get_variable(name="bias_j", - shape=[self._num_units], - dtype=dtype, - initializer= - init_ops.constant_initializer(0.0, dtype=dtype)) - bf = vs.get_variable(name="bias_f", - shape=[self._num_units], - dtype=dtype, - initializer= - init_ops.constant_initializer(0.0, dtype=dtype)) - bo = vs.get_variable(name="bias_o", - shape=[self._num_units], - dtype=dtype, - initializer= - init_ops.constant_initializer(0.0, dtype=dtype)) + bi = vs.get_variable( + name="bias_i", + shape=[self._num_units], + dtype=dtype, + initializer=init_ops.constant_initializer(0.0, dtype=dtype)) + bj = vs.get_variable( + name="bias_j", + shape=[self._num_units], + dtype=dtype, + initializer=init_ops.constant_initializer(0.0, dtype=dtype)) + bf = vs.get_variable( + name="bias_f", + shape=[self._num_units], + dtype=dtype, + initializer=init_ops.constant_initializer(0.0, dtype=dtype)) + bo = vs.get_variable( + name="bias_o", + shape=[self._num_units], + dtype=dtype, + initializer=init_ops.constant_initializer(0.0, dtype=dtype)) i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi) j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj) f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf) o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo) - c = (math_ops.sigmoid(f + self._forget_bias) * c_prev + - math_ops.sigmoid(i) * math_ops.tanh(j)) + c = ( + math_ops.sigmoid(f + self._forget_bias) * c_prev + + math_ops.sigmoid(i) * math_ops.tanh(j)) m = math_ops.sigmoid(o) * self._activation(c) if self._num_proj is not None: @@ -2635,10 +2682,12 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell): class SRUCell(rnn_cell_impl._LayerRNNCell): """SRU, Simple Recurrent Unit + Implementation based on Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755). - This variation of RNN cell is characterized by the simplified data dependence + This variation of RNN cell is characterized by the simplified data + dependence between hidden states of two consecutive time steps. Traditionally, hidden states from a cell at time step t-1 needs to be multiplied with a matrix W_hh before being fed into the ensuing cell at time step t. @@ -2656,8 +2705,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell): will share weights, but to avoid mistakes we require reuse=True in such cases. """ - def __init__(self, num_units, - activation=None, reuse=None, name=None): + + def __init__(self, num_units, activation=None, reuse=None, name=None): super(SRUCell, self).__init__(_reuse=reuse, name=name) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -2675,8 +2724,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell): def build(self, inputs_shape): if inputs_shape[1].value is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) input_depth = inputs_shape[1].value @@ -2711,15 +2760,276 @@ class SRUCell(rnn_cell_impl._LayerRNNCell): """Simple recurrent unit (SRU) with num_units cells.""" U = math_ops.matmul(inputs, self._kernel) - x_bar, f_intermediate, r_intermediate = array_ops.split(value=U, - num_or_size_splits=3, - axis=1) + x_bar, f_intermediate, r_intermediate = array_ops.split( + value=U, num_or_size_splits=3, axis=1) - f_r = math_ops.sigmoid(nn_ops.bias_add(array_ops.concat( - [f_intermediate, r_intermediate], 1), self._bias)) + f_r = math_ops.sigmoid( + nn_ops.bias_add( + array_ops.concat([f_intermediate, r_intermediate], 1), self._bias)) f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1) c = f * state + (1.0 - f) * x_bar h = r * self._activation(c) + (1.0 - r) * inputs return h, c + + +class WeightNormLSTMCell(rnn_cell_impl.RNNCell): + """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`. + + The weight-norm implementation is based on: + https://arxiv.org/abs/1602.07868 + Tim Salimans, Diederik P. Kingma. + Weight Normalization: A Simple Reparameterization to Accelerate + Training of Deep Neural Networks + + The default LSTM implementation based on: + http://www.bioinf.jku.at/publications/older/2604.pdf + S. Hochreiter and J. Schmidhuber. + "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + + The class uses optional peephole connections, optional cell clipping + and an optional projection layer. + + The optional peephole implementation is based on: + https://research.google.com/pubs/archive/43905.pdf + Hasim Sak, Andrew Senior, and Francoise Beaufays. + "Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + """ + + def __init__(self, + num_units, + norm=True, + use_peepholes=False, + cell_clip=None, + initializer=None, + num_proj=None, + proj_clip=None, + forget_bias=1, + activation=None, + reuse=None): + """Initialize the parameters of a weight-normalized LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell + norm: If `True`, apply normalization to the weight matrices. If False, + the result is identical to that obtained from `rnn_cell_impl.LSTMCell` + use_peepholes: bool, set `True` to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + forget_bias: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. + activation: Activation function of the inner states. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + """ + super(WeightNormLSTMCell, self).__init__(_reuse=reuse) + + self._scope = "wn_lstm_cell" + self._num_units = num_units + self._norm = norm + self._initializer = initializer + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._num_proj = num_proj + self._proj_clip = proj_clip + self._activation = activation or math_ops.tanh + self._forget_bias = forget_bias + + self._weights_variable_name = "kernel" + self._bias_variable_name = "bias" + + if num_proj: + self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) + self._output_size = num_proj + else: + self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + def _normalize(self, weight, name): + """Apply weight normalization. + + Args: + weight: a 2D tensor with known number of columns. + name: string, variable name for the normalizer. + Returns: + A tensor with the same shape as `weight`. + """ + + output_size = weight.get_shape().as_list()[1] + g = vs.get_variable(name, [output_size], dtype=weight.dtype) + return nn_impl.l2_normalize(weight, dim=0) * g + + def _linear(self, + args, + output_size, + norm, + bias, + bias_initializer=None, + kernel_initializer=None): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. + + Args: + args: a 2D Tensor or a list of 2D, batch x n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_initializer: starting value to initialize the bias + (default is all zeros). + kernel_initializer: starting value to initialize the weight. + + Returns: + A 2D Tensor with shape [batch x output_size] equal to + sum_i(args[i] * W[i]), where W[i]s are newly created matrices. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + if args is None or (nest.is_sequence(args) and not args): + raise ValueError("`args` must be specified") + if not nest.is_sequence(args): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape() for a in args] + for shape in shapes: + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %s" % (shape, shape[1])) + else: + total_arg_size += shape[1].value + + dtype = [a.dtype for a in args][0] + + # Now the computation. + scope = vs.get_variable_scope() + with vs.variable_scope(scope) as outer_scope: + weights = vs.get_variable( + self._weights_variable_name, [total_arg_size, output_size], + dtype=dtype, + initializer=kernel_initializer) + if norm: + wn = [] + st = 0 + with ops.control_dependencies(None): + for i in range(len(args)): + en = st + shapes[i][1].value + wn.append( + self._normalize(weights[st:en, :], name="norm_{}".format(i))) + st = en + + weights = array_ops.concat(wn, axis=0) + + if len(args) == 1: + res = math_ops.matmul(args[0], weights) + else: + res = math_ops.matmul(array_ops.concat(args, 1), weights) + if not bias: + return res + + with vs.variable_scope(outer_scope) as inner_scope: + inner_scope.set_partitioner(None) + if bias_initializer is None: + bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) + + biases = vs.get_variable( + self._bias_variable_name, [output_size], + dtype=dtype, + initializer=bias_initializer) + + return nn_ops.bias_add(res, biases) + + def call(self, inputs, state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, 2D, batch x num_units. + state: A tuple of state Tensors, both `2-D`, with column sizes + `c_state` and `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch x output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + dtype = inputs.dtype + num_units = self._num_units + sigmoid = math_ops.sigmoid + c, h = state + + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + with vs.variable_scope(self._scope, initializer=self._initializer): + + concat = self._linear( + [inputs, h], 4 * num_units, norm=self._norm, bias=True) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) + + if self._use_peepholes: + w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype) + w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype) + w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype) + + new_c = ( + c * sigmoid(f + self._forget_bias + w_f_diag * c) + + sigmoid(i + w_i_diag * c) * self._activation(j)) + else: + new_c = ( + c * sigmoid(f + self._forget_bias) + + sigmoid(i) * self._activation(j)) + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c) + else: + new_h = sigmoid(o) * self._activation(new_c) + + if self._num_proj is not None: + with vs.variable_scope("projection"): + new_h = self._linear( + new_h, self._num_proj, norm=self._norm, bias=False) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + new_h = clip_ops.clip_by_value(new_h, -self._proj_clip, + self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) + return new_h, new_state diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h index c0df224bc8cffcb485db38dea270600c71070dff..b732cdd41e5c39793c17fa920c115e2bbe96f5de 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h @@ -15,8 +15,8 @@ limitations under the License. // Helpers for working with the SignatureDefs of TensorFlow SavedModels. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ +#define TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ #include #include @@ -66,4 +66,4 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h index 693b02dc437afdf14c38e4224c5469bb3e569540..34da8c82cdab9b6f82af328c49a365ae1cb951ed 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ +#define TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" @@ -38,4 +38,4 @@ struct GatherTree { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ +#endif // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index f498b2bb5709ea28faca1c5cfa21ad30aac14ab7..926554031775202d7f7d9018cf6ae4efb34fe96b 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -46,20 +46,18 @@ class TestGatherTree(test.TestCase): # create (batch_size, max_time, beam_width) matrix and transpose it predicted_ids = np.array( - [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[2, 3, 4], [5, 6, 7], [8, 9, 10]]], + [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]], dtype=np.int32).transpose([1, 0, 2]) parent_ids = np.array( - [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], - [[0, 0, 0], [1, 2, 0], [2, 1, 1]]], + [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]], dtype=np.int32).transpose([1, 0, 2]) # sequence_lengths is shaped (batch_size = 3) max_sequence_lengths = [3, 3] - expected_result = np.array( - [[[2, 2, 2], [6, 5, 6], [7, 8, 9]], - [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2]) + expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]], + [[2, 4, 4], [7, 6, 6], + [8, 9, 10]]]).transpose([1, 0, 2]) res = beam_search_ops.gather_tree( predicted_ids, @@ -157,8 +155,8 @@ class TestBeamStep(test.TestCase): self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]]) self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]]) self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]]) - self.assertAllEqual(next_state_.finished, [[False, False, False], - [False, False, False]]) + self.assertAllEqual(next_state_.finished, + [[False, False, False], [False, False, False]]) expected_log_probs = [] expected_log_probs.append(state_.log_probs[0][[1, 0, 0]]) @@ -212,8 +210,8 @@ class TestBeamStep(test.TestCase): self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]]) self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]]) self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]]) - self.assertAllEqual(next_state_.finished, [[True, False, False], - [False, True, False]]) + self.assertAllEqual(next_state_.finished, + [[True, False, False], [False, True, False]]) expected_log_probs = [] expected_log_probs.append(state_.log_probs[0][[1, 0, 0]]) @@ -226,9 +224,10 @@ class TestBeamStep(test.TestCase): class TestLargeBeamStep(test.TestCase): - """ - Tests a single step of beam search in such - case that beam size is larger than vocabulary size. + """Tests large beam step. + + Tests a single step of beam search in such case that beam size is larger than + vocabulary size. """ def setUp(self): @@ -239,19 +238,21 @@ class TestLargeBeamStep(test.TestCase): self.end_token = 0 self.length_penalty_weight = 0.6 - def test_step(self): - def get_probs(): - """this simulates the initialize method in BeamSearchDecoder""" - log_prob_mask = array_ops.one_hot(array_ops.zeros([self.batch_size], - dtype=dtypes.int32), - depth=self.beam_width, on_value=True, - off_value=False, dtype=dtypes.bool) - log_prob_zeros = array_ops.zeros([self.batch_size, self.beam_width], - dtype=dtypes.float32) - log_prob_neg_inf = array_ops.ones([self.batch_size, self.beam_width], - dtype=dtypes.float32) * -np.Inf + def get_probs(): + """this simulates the initialize method in BeamSearchDecoder.""" + log_prob_mask = array_ops.one_hot( + array_ops.zeros([self.batch_size], dtype=dtypes.int32), + depth=self.beam_width, + on_value=True, + off_value=False, + dtype=dtypes.bool) + + log_prob_zeros = array_ops.zeros( + [self.batch_size, self.beam_width], dtype=dtypes.float32) + log_prob_neg_inf = array_ops.ones( + [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf log_probs = array_ops.where(log_prob_mask, log_prob_zeros, log_prob_neg_inf) @@ -260,12 +261,15 @@ class TestLargeBeamStep(test.TestCase): log_probs = get_probs() dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width]) + # pylint: disable=invalid-name _finished = array_ops.one_hot( array_ops.zeros([self.batch_size], dtype=dtypes.int32), - depth=self.beam_width, on_value=False, - off_value=True, dtype=dtypes.bool) + depth=self.beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64) - _lengths[:, 0]=2 + _lengths[:, 0] = 2 _lengths = constant_op.constant(_lengths, dtype=dtypes.int64) beam_state = beam_search_decoder.BeamSearchDecoderState( @@ -298,20 +302,20 @@ class TestLargeBeamStep(test.TestCase): length_penalty_weight=self.length_penalty_weight) with self.test_session() as sess: - outputs_, next_state_, state_, log_probs_ = sess.run( + outputs_, next_state_, _, _ = sess.run( [outputs, next_beam_state, beam_state, log_probs]) self.assertEqual(outputs_.predicted_ids[0, 0], 3) self.assertEqual(outputs_.predicted_ids[0, 1], 2) self.assertEqual(outputs_.predicted_ids[1, 0], 1) neg_inf = -np.Inf - self.assertAllEqual(next_state_.log_probs[:, -3:], - [[neg_inf, neg_inf, neg_inf], - [neg_inf, neg_inf, neg_inf]]) + self.assertAllEqual( + next_state_.log_probs[:, -3:], + [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]]) self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True) self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True) - self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], - [0, 0, 0]]) + self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]]) + class BeamSearchDecoderTest(test.TestCase): @@ -338,8 +342,8 @@ class BeamSearchDecoderTest(test.TestCase): initial_state = cell.zero_state(batch_size, dtypes.float32) if has_attention: inputs = array_ops.placeholder_with_default( - np.random.randn(batch_size, decoder_max_time, - input_depth).astype(np.float32), + np.random.randn(batch_size, decoder_max_time, input_depth).astype( + np.float32), shape=(None, None, input_depth)) tiled_inputs = beam_search_decoder.tile_batch( inputs, multiplier=beam_width) @@ -359,8 +363,7 @@ class BeamSearchDecoderTest(test.TestCase): cell_state = cell.zero_state( dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) if has_attention: - cell_state = cell_state.clone( - cell_state=initial_state) + cell_state = cell_state.clone(cell_state=initial_state) bsd = beam_search_decoder.BeamSearchDecoder( cell=cell, embedding=embedding, diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index a5f7169c3106d12cd22e822dca96c6adf43a45fe..d6184d61095f727f9dcab56fe59e2601868c1624 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -37,7 +37,6 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest - __all__ = [ "BeamSearchDecoderOutput", "BeamSearchDecoderState", @@ -48,8 +47,8 @@ __all__ = [ class BeamSearchDecoderState( - collections.namedtuple("BeamSearchDecoderState", ("cell_state", "log_probs", - "finished", "lengths"))): + collections.namedtuple("BeamSearchDecoderState", + ("cell_state", "log_probs", "finished", "lengths"))): pass @@ -85,11 +84,12 @@ def _tile_batch(t, multiplier): tiled_static_batch_size = ( t.shape[0].value * multiplier if t.shape[0].value is not None else None) tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) - tiled = array_ops.reshape( - tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0)) + tiled = array_ops.reshape(tiled, + array_ops.concat( + ([shape_t[0] * multiplier], shape_t[1:]), 0)) tiled.set_shape( - tensor_shape.TensorShape( - [tiled_static_batch_size]).concatenate(t.shape[1:])) + tensor_shape.TensorShape([tiled_static_batch_size]).concatenate( + t.shape[1:])) return tiled @@ -197,8 +197,8 @@ class BeamSearchDecoder(decoder.Decoder): """ if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) - if (output_layer is not None - and not isinstance(output_layer, layers_base.Layer)): + if (output_layer is not None and + not isinstance(output_layer, layers_base.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell @@ -223,16 +223,17 @@ class BeamSearchDecoder(decoder.Decoder): self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._initial_cell_state = nest.map_structure( - self._maybe_split_batch_beams, - initial_state, self._cell.state_size) + self._maybe_split_batch_beams, initial_state, self._cell.state_size) self._start_tokens = array_ops.tile( array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) self._start_inputs = self._embedding_fn(self._start_tokens) - + self._finished = array_ops.one_hot( array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, on_value=False, - off_value=True, dtype=dtypes.bool) + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) @property def batch_size(self): @@ -250,8 +251,7 @@ class BeamSearchDecoder(decoder.Decoder): # dimensions to get the output size of the rnn with the layer # applied to the top. output_shape_with_unknown_batch = nest.map_structure( - lambda s: tensor_shape.TensorShape([None]).concatenate(s), - size) + lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) layer_output_shape = self._output_layer.compute_output_shape( output_shape_with_unknown_batch) return nest.map_structure(lambda s: s[1:], layer_output_shape) @@ -302,10 +302,11 @@ class BeamSearchDecoder(decoder.Decoder): log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, on_value=0.0, off_value=-np.Inf, + depth=self._beam_width, + on_value=0.0, + off_value=-np.Inf, dtype=nest.flatten(self._initial_cell_state)[0].dtype) - initial_state = BeamSearchDecoderState( cell_state=self._initial_cell_state, log_probs=log_probs, @@ -365,11 +366,12 @@ class BeamSearchDecoder(decoder.Decoder): t_shape = array_ops.shape(t) static_batch_size = tensor_util.constant_value(self._batch_size) batch_size_beam_width = ( - None if static_batch_size is None - else static_batch_size * self._beam_width) + None + if static_batch_size is None else static_batch_size * self._beam_width) reshaped_t = array_ops.reshape( - t, array_ops.concat( - ([self._batch_size * self._beam_width], t_shape[2:]), 0)) + t, + array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]), + 0)) reshaped_t.set_shape( (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s))) return reshaped_t @@ -398,8 +400,9 @@ class BeamSearchDecoder(decoder.Decoder): s = tensor_shape.TensorShape(s) t_shape = array_ops.shape(t) reshaped_t = array_ops.reshape( - t, array_ops.concat( - ([self._batch_size, self._beam_width], t_shape[1:]), 0)) + t, + array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]), + 0)) static_batch_size = tensor_util.constant_value(self._batch_size) expected_reshaped_shape = tensor_shape.TensorShape( [static_batch_size, self._beam_width]).concatenate(s) @@ -409,8 +412,8 @@ class BeamSearchDecoder(decoder.Decoder): "We expected it to have shape " "(batch_size, beam_width, depth) == %s. Perhaps you " "forgot to create a zero_state with " - "batch_size=encoder_batch_size * beam_width?" - % (reshaped_t.shape, expected_reshaped_shape)) + "batch_size=encoder_batch_size * beam_width?" % + (reshaped_t.shape, expected_reshaped_shape)) reshaped_t.set_shape(expected_reshaped_shape) return reshaped_t @@ -482,15 +485,13 @@ class BeamSearchDecoder(decoder.Decoder): cell_state = state.cell_state inputs = nest.map_structure( lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs) - cell_state = nest.map_structure( - self._maybe_merge_batch_beams, - cell_state, self._cell.state_size) + cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, + self._cell.state_size) cell_outputs, next_cell_state = self._cell(inputs, cell_state) cell_outputs = nest.map_structure( lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs) next_cell_state = nest.map_structure( - self._maybe_split_batch_beams, - next_cell_state, self._cell.state_size) + self._maybe_split_batch_beams, next_cell_state, self._cell.state_size) if self._output_layer is not None: cell_outputs = self._output_layer(cell_outputs) @@ -553,7 +554,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, lengths_to_add = array_ops.one_hot( indices=array_ops.fill([batch_size, beam_width], end_token), depth=vocab_size, - on_value=np.int64(0), off_value=np.int64(1), + on_value=np.int64(0), + off_value=np.int64(1), dtype=dtypes.int64) add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished)) lengths_to_add *= array_ops.expand_dims(add_mask, 2) @@ -572,8 +574,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, scores_flat = array_ops.reshape(scores, [batch_size, -1]) # Pick the next beams according to the specified successors function - next_beam_size = ops.convert_to_tensor(beam_width, dtype=dtypes.int32, - name="beam_width") + next_beam_size = ops.convert_to_tensor( + beam_width, dtype=dtypes.int32, name="beam_width") next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) next_beam_scores.set_shape([static_batch_size, beam_width]) @@ -592,11 +594,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # name="next_beam_word_ids") # would be a lot cleaner but for reasons unclear, that hides the results of # the op which prevents capturing it with tfdbg debug ops. - raw_next_word_ids = math_ops.mod(word_indices, vocab_size, - name="next_beam_word_ids") + raw_next_word_ids = math_ops.mod( + word_indices, vocab_size, name="next_beam_word_ids") next_word_ids = math_ops.to_int32(raw_next_word_ids) - next_beam_ids = math_ops.to_int32(word_indices / vocab_size, - name="next_beam_parent_ids") + next_beam_ids = math_ops.to_int32( + word_indices / vocab_size, name="next_beam_parent_ids") # Append new ids to current predictions previously_finished = _tensor_gather_helper( @@ -605,9 +607,10 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, batch_size=batch_size, range_size=beam_width, gather_shape=[-1]) - next_finished = math_ops.logical_or(previously_finished, - math_ops.equal(next_word_ids, end_token), - name="next_beam_finished") + next_finished = math_ops.logical_or( + previously_finished, + math_ops.equal(next_word_ids, end_token), + name="next_beam_finished") # Calculate the length of the next predictions. # 1. Finished beams remain unchanged. @@ -768,8 +771,12 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, return gather_from -def _tensor_gather_helper(gather_indices, gather_from, batch_size, - range_size, gather_shape, name=None): +def _tensor_gather_helper(gather_indices, + gather_from, + batch_size, + range_size, + gather_shape, + name=None): """Helper for gathering the right indices from the tensor. This works by reshaping gather_from to gather_shape (e.g. [-1]) and then @@ -800,9 +807,9 @@ def _tensor_gather_helper(gather_indices, gather_from, batch_size, array_ops.reshape(gather_from, gather_shape), gather_indices) final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)] static_batch_size = tensor_util.constant_value(batch_size) - final_static_shape = (tensor_shape.TensorShape([static_batch_size]) - .concatenate( - gather_from.shape[1:1 + len(gather_shape)])) + final_static_shape = ( + tensor_shape.TensorShape([static_batch_size]).concatenate( + gather_from.shape[1:1 + len(gather_shape)])) output = array_ops.reshape(output, final_shape, name="output") output.set_shape(final_static_shape) return output diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h index e24efa0de14824044591b954b8465ebeebc10dd5..4628b6ab1b1164addef6aaf930a0dbe7091cd16d 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.h +++ b/tensorflow/contrib/session_bundle/bundle_shim.h @@ -15,8 +15,8 @@ limitations under the License. // Shim for systems that need to load both SessionBundle and // SavedModelBundle interchangeably during migration to SavedModel. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_ +#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_ +#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_ #include @@ -67,4 +67,4 @@ Status LoadSessionBundleOrSavedModelBundle( } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_ +#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_ diff --git a/tensorflow/contrib/session_bundle/bundle_shim.py b/tensorflow/contrib/session_bundle/bundle_shim.py index 062c9cc68046c59ffd04190dad0fa69f5f9dfa0a..3149875e41f6f77b3bcbc0ab1a150cfdc59ad2ba 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.py +++ b/tensorflow/contrib/session_bundle/bundle_shim.py @@ -82,7 +82,7 @@ def _convert_default_signature_to_signature_def(signatures): """ default_signature = signatures.default_signature signature_def = meta_graph_pb2.SignatureDef() - if default_signature.WhichOneof("type") == "regression_signature": + if default_signature.WhichOneof("type") == legacy_constants.REGRESSION_SIGNATURE: regression_signature = default_signature.regression_signature signature_def.method_name = signature_constants.REGRESS_METHOD_NAME _add_input_to_signature_def(regression_signature.input.tensor_name, @@ -91,7 +91,7 @@ def _convert_default_signature_to_signature_def(signatures): _add_output_to_signature_def(regression_signature.output.tensor_name, signature_constants.REGRESS_OUTPUTS, signature_def) - elif default_signature.WhichOneof("type") == "classification_signature": + elif default_signature.WhichOneof("type") == legacy_constants.CLASSIFICATION_SIGNATURE: classification_signature = default_signature.classification_signature signature_def.method_name = signature_constants.CLASSIFY_METHOD_NAME _add_input_to_signature_def(classification_signature.input.tensor_name, @@ -132,8 +132,8 @@ def _convert_named_signatures_to_signature_def(signatures): signature_constants.PREDICT_OUTPUTS] # TODO(pdudnik): what if there are other signatures? Mimic cr/140900781 once # it is submitted. - if (input_signature.WhichOneof("type") != "generic_signature" or - output_signature.WhichOneof("type") != "generic_signature"): + if (input_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE or + output_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE): raise RuntimeError("Named input and output signatures can only be " "up-converted if they are generic signature. " "Input signature type is %s, output signature type is " diff --git a/tensorflow/contrib/session_bundle/constants.py b/tensorflow/contrib/session_bundle/constants.py index 6ced73241afdda047b8feacb26fedd72363b6240..e833baee791f97df5829ee289bcaf17c31a17deb 100644 --- a/tensorflow/contrib/session_bundle/constants.py +++ b/tensorflow/contrib/session_bundle/constants.py @@ -32,3 +32,6 @@ INIT_OP_KEY = "serving_init_op" SIGNATURES_KEY = "serving_signatures" ASSETS_KEY = "serving_assets" GRAPH_KEY = "serving_graph" +REGRESSION_SIGNATURE = "regression_signature" +CLASSIFICATION_SIGNATURE = "classification_signature" +GENERIC_SIGNATURE = "generic_signature" diff --git a/tensorflow/contrib/session_bundle/session_bundle.h b/tensorflow/contrib/session_bundle/session_bundle.h index 2ff258411d1928cea7da4f637ffe94f144b2b60a..b2be46efa6d1e7ceb1fb66a7148735b86cc68dd3 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.h +++ b/tensorflow/contrib/session_bundle/session_bundle.h @@ -15,8 +15,8 @@ limitations under the License. // Low-level functionality for setting up a inference Session. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ +#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ +#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ #include @@ -82,4 +82,4 @@ bool IsPossibleExportDirectory(const StringPiece export_dir); } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ +#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ diff --git a/tensorflow/contrib/session_bundle/signature.h b/tensorflow/contrib/session_bundle/signature.h index 0049bea00822db85c606b9e6d00ae4db83804bab..4ef1277cec413a6fcfb54721520279d024f18bc1 100644 --- a/tensorflow/contrib/session_bundle/signature.h +++ b/tensorflow/contrib/session_bundle/signature.h @@ -15,8 +15,8 @@ limitations under the License. // Helpers for working with TensorFlow exports and their signatures. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ +#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ +#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ #include #include @@ -121,4 +121,4 @@ Status BindGenericNames(const GenericSignature& signature, } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ +#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ diff --git a/tensorflow/contrib/session_bundle/test_util.h b/tensorflow/contrib/session_bundle/test_util.h index dd0fc8d1c0c47c444ac6fe807435fa671f3939f0..f0d41ce5a4b901db80a7a01475dc0917e966dc89 100644 --- a/tensorflow/contrib/session_bundle/test_util.h +++ b/tensorflow/contrib/session_bundle/test_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_ #include @@ -35,4 +35,4 @@ string TestSrcDirPath(const string& relative_path); } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_ diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 870f504d10362ed5226951adefc3ba9a934900c1..f5a9299d263450ba89617f38bf7a4c5cbc359cb1 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -236,7 +236,7 @@ class SingleEvaluationTest(test.TestCase): def _prepareCheckpoint(self, checkpoint_path): init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) + saver = saver_lib.Saver() with self.test_session() as sess: sess.run(init_op) saver.save(sess, checkpoint_path) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index 930df2414bc907703c2670ffd92134727a28e856..7b609ae96b20a5c3d078777cc8fbb475e5eebb1b 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -45,32 +45,55 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_) # Make a selfadjoint, positive definite. a_np = np.dot(a_np.T, a_np) + # jacobi preconditioner + jacobi_np = np.zeros_like(a_np) + jacobi_np[range(a_np.shape[0]), range(a_np.shape[1])] = (1.0 / + a_np.diagonal()) rhs_np = np.random.uniform( low=-1.0, high=1.0, size=shape_[0]).astype(dtype_) + x_np = np.zeros_like(rhs_np) tol = 1e-6 if dtype_ == np.float64 else 1e-3 max_iter = 20 with self.test_session() as sess: if use_static_shape_: a = constant_op.constant(a_np) rhs = constant_op.constant(rhs_np) + x = constant_op.constant(x_np) + jacobi = constant_op.constant(jacobi_np) else: a = array_ops.placeholder(dtype_) rhs = array_ops.placeholder(dtype_) + x = array_ops.placeholder(dtype_) + jacobi = array_ops.placeholder(dtype_) operator = util.create_operator(a) - cg_graph = linear_equations.conjugate_gradient( - operator, rhs, tol=tol, max_iter=max_iter) - if use_static_shape_: - cg_val = sess.run(cg_graph) - else: - cg_val = sess.run(cg_graph, feed_dict={a: a_np, rhs: rhs_np}) - norm_r0 = np.linalg.norm(rhs_np) - norm_r = np.sqrt(cg_val.gamma) - self.assertLessEqual(norm_r, tol * norm_r0) - # Validate that we get an equally small residual norm with numpy - # using the computed solution. - r_np = rhs_np - np.dot(a_np, cg_val.x) - norm_r_np = np.linalg.norm(r_np) - self.assertLessEqual(norm_r_np, tol * norm_r0) + preconditioners = [None, util.identity_operator(a), + util.create_operator(jacobi)] + cg_results = [] + for preconditioner in preconditioners: + cg_graph = linear_equations.conjugate_gradient( + operator, rhs, preconditioner=preconditioner, + x=x, tol=tol, max_iter=max_iter) + if use_static_shape_: + cg_val = sess.run(cg_graph) + else: + cg_val = sess.run(cg_graph, feed_dict={a: a_np, rhs: rhs_np, x: x_np, + jacobi: jacobi_np}) + norm_r0 = np.linalg.norm(rhs_np) + norm_r = np.linalg.norm(cg_val.r) + self.assertLessEqual(norm_r, tol * norm_r0) + # Validate that we get an equally small residual norm with numpy + # using the computed solution. + r_np = rhs_np - np.dot(a_np, cg_val.x) + norm_r_np = np.linalg.norm(r_np) + self.assertLessEqual(norm_r_np, tol * norm_r0) + cg_results.append(cg_val) + # Validate that we get same results using identity_preconditioner + # and None + self.assertEqual(cg_results[0].i, cg_results[1].i) + self.assertAlmostEqual(cg_results[0].gamma, cg_results[1].gamma) + self.assertAllClose(cg_results[0].r, cg_results[1].r, rtol=tol) + self.assertAllClose(cg_results[0].x, cg_results[1].x, rtol=tol) + self.assertAllClose(cg_results[0].p, cg_results[1].p, rtol=tol) return [test_conjugate_gradient] diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py index 1566984b27fdab4c2a8c91bd16f587747e69e9e5..12e94369cbae462c21867657119cd2dd9ee29651 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py @@ -63,6 +63,41 @@ class UtilTest(test.TestCase): def testCreateOperatorUnknownShape(self): self._testCreateOperator(False) + def _testIdentityOperator(self, use_static_shape_): + for dtype in np.float32, np.float64: + a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype) + x_np = np.array([[2.], [-3.]], dtype=dtype) + y_np = np.array([[2], [-3.], [5.]], dtype=dtype) + with self.test_session() as sess: + if use_static_shape_: + a = constant_op.constant(a_np, dtype=dtype) + x = constant_op.constant(x_np, dtype=dtype) + y = constant_op.constant(y_np, dtype=dtype) + else: + a = array_ops.placeholder(dtype) + x = array_ops.placeholder(dtype) + y = array_ops.placeholder(dtype) + id_op = util.identity_operator(a) + ax = id_op.apply(x) + aty = id_op.apply_adjoint(y) + op_shape = ops.convert_to_tensor(id_op.shape) + if use_static_shape_: + op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty]) + else: + op_shape_val, ax_val, aty_val = sess.run( + [op_shape, ax, aty], feed_dict={a: a_np, + x: x_np, + y: y_np}) + self.assertAllEqual(op_shape_val, [3, 2]) + self.assertAllClose(ax_val, x_np) + self.assertAllClose(aty_val, y_np) + + def testIdentityOperator(self): + self._testIdentityOperator(True) + + def testIdentityOperatorUnknownShape(self): + self._testIdentityOperator(False) + def testL2Norm(self): with self.test_session(): x_np = np.array([[2], [-3.], [5.]]) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 8cba56eba6b3046b8efbbbbf130705255e1c13bb..4dfaa97ac9834ca3c13a9f8e8d721ddaba33bf7d 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -27,10 +27,13 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import linalg_ops def conjugate_gradient(operator, rhs, + preconditioner=None, + x=None, tol=1e-4, max_iter=20, name="conjugate_gradient"): @@ -55,6 +58,15 @@ def conjugate_gradient(operator, vector with the result of applying the operator to `x`, i.e. if `operator` represents matrix `A`, `apply` should return `A * x`. rhs: A rank-1 `Tensor` of shape `[N]` containing the right-hand size vector. + preconditioner: An object representing a linear operator, see `operator` + for detail. The preconditioner should approximate the inverse of `A`. + An efficient preconditioner could dramatically improve the rate of + convergence. If `preconditioner` represents matrix `M`(`M` approximates + `A^{-1}`), the algorithm uses `preconditioner.apply(x)` to estimate + `A^{-1}x`. For this to be useful, the cost of applying `M` should be + much lower than computing `A^{-1}` directly. + x: A rank-1 `Tensor` of shape `[N]` containing the initial guess for the + solution. tol: A float scalar convergence tolerance. max_iter: An integer giving the maximum number of iterations. name: A name scope for the operation. @@ -65,35 +77,51 @@ def conjugate_gradient(operator, - x: A rank-1 `Tensor` of shape `[N]` containing the computed solution. - r: A rank-1 `Tensor` of shape `[M]` containing the residual vector. - p: A rank-1 `Tensor` of shape `[N]`. `A`-conjugate basis vector. - - gamma: \\(||r||_2^2\\) + - gamma: \\(r \dot M \dot r\\), equivalent to \\(||r||_2^2\\) when + `preconditioner=None`. """ # ephemeral class holding CG state. cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"]) def stopping_criterion(i, state): - return math_ops.logical_and(i < max_iter, state.gamma > tol) + return math_ops.logical_and(i < max_iter, + linalg_ops.norm(state.r) > tol) - # TODO(rmlarsen): add preconditioning def cg_step(i, state): z = operator.apply(state.p) alpha = state.gamma / util.dot(state.p, z) x = state.x + alpha * state.p r = state.r - alpha * z - gamma = util.l2norm_squared(r) - beta = gamma / state.gamma - p = r + beta * state.p + if preconditioner is None: + gamma = util.dot(r, r) + beta = gamma / state.gamma + p = r + beta * state.p + else: + q = preconditioner.apply(r) + gamma = util.dot(r, q) + beta = gamma / state.gamma + p = q + beta * state.p return i + 1, cg_state(i + 1, x, r, p, gamma) with ops.name_scope(name): n = operator.shape[1:] rhs = array_ops.expand_dims(rhs, -1) - gamma0 = util.l2norm_squared(rhs) - tol = tol * tol * gamma0 - x = array_ops.expand_dims( - array_ops.zeros( - n, dtype=rhs.dtype.base_dtype), -1) + if x is None: + x = array_ops.expand_dims( + array_ops.zeros( + n, dtype=rhs.dtype.base_dtype), -1) + r0 = rhs + else: + x = array_ops.expand_dims(x, -1) + r0 = rhs - operator.apply(x) + if preconditioner is None: + p0 = r0 + else: + p0 = preconditioner.apply(r0) + gamma0 = util.dot(r0, p0) + tol = tol * linalg_ops.norm(r0) i = constant_op.constant(0, dtype=dtypes.int32) - state = cg_state(i=i, x=x, r=rhs, p=rhs, gamma=gamma0) + state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0) _, state = control_flow_ops.while_loop(stopping_criterion, cg_step, [i, state]) return cg_state( diff --git a/tensorflow/contrib/solvers/python/ops/util.py b/tensorflow/contrib/solvers/python/ops/util.py index 777e0c185d6c9fffab6a7fe6e6ae4c133c62ad1a..96947e8eea1006bcd03cf09cd13cd1266695cc2e 100644 --- a/tensorflow/contrib/solvers/python/ops/util.py +++ b/tensorflow/contrib/solvers/python/ops/util.py @@ -45,6 +45,23 @@ def create_operator(matrix): apply_adjoint=lambda v: math_ops.matmul(matrix, v, adjoint_a=True)) +def identity_operator(matrix): + """Creates a linear operator from a rank-2 identity tensor.""" + + linear_operator = collections.namedtuple( + "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"]) + shape = matrix.get_shape() + if shape.is_fully_defined(): + shape = shape.as_list() + else: + shape = array_ops.shape(matrix) + return linear_operator( + shape=shape, + dtype=matrix.dtype, + apply=lambda v: v, + apply_adjoint=lambda v: v) + + # TODO(rmlarsen): Measure if we should just call matmul. def dot(x, y): return math_ops.reduce_sum(math_ops.conj(x) * y) diff --git a/tensorflow/contrib/tensor_forest/kernels/data_spec.h b/tensorflow/contrib/tensor_forest/kernels/data_spec.h index 05590d6992e2fd7eeee8d242561229ab53bb16de..0a3abe56dfc4f611ac8ed0815e4c74a639d2477e 100644 --- a/tensorflow/contrib/tensor_forest/kernels/data_spec.h +++ b/tensorflow/contrib/tensor_forest/kernels/data_spec.h @@ -15,8 +15,8 @@ // This is a surrogate for using a proto, since it doesn't seem to be possible // to use protos in a dynamically-loaded/shared-linkage library, which is // what is used for custom ops in tensorflow/contrib. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ #include #include "tensorflow/core/lib/strings/numbers.h" @@ -138,4 +138,4 @@ class TensorForestDataSpec { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h index 35f9fb7eaf4d73e98293f1ea6a4b45b71212a92c..dad9df4898844eaa17bdfe5b4b298a95377fd12e 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_ #include @@ -307,4 +307,4 @@ void GetParentWeightedMean(float leaf_sum, const float* leaf_data, } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h index 4bd1f06c72945f73e50301c337692e0b510d3693..2e7368dc12c74b9dc44b72394668bf2de71f2f90 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ #include #include @@ -70,4 +70,4 @@ class CandidateGraphRunner { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index bf88216d663cc9b69746a93379124bf1d9a30df9..cced26b9036ba8ba6c5994b7483261a062f80588 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" @@ -88,4 +88,4 @@ class DecisionTreeResource : public ResourceBase { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h index 3f03c2d05bb1090fa75f4b6e7ad4f00caaea61a4..85ce7b825b11983307370bb3ac30eeec9b6b2c99 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" @@ -104,4 +104,4 @@ struct CandidateEvalatorCollection { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index dacf033d99018d47787b644b12d3181780df7113..0d6712e9e552d7045eb198f7e65d04eb42eff920 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ #include @@ -98,4 +98,4 @@ class FertileStatsResource : public ResourceBase { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h index 2ae3a79b3dd69b3fd3d31a055589b2edc63afa3c..4ae48179afc8452e6a3ec61dede16b9941482bcc 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ #include #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" @@ -78,4 +78,4 @@ class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h index 3e41ab50b9d78943db8ee58aab85a8c7541e2320..f938d08c84d72b4c5a71e8f7fb0f639aa70e3e49 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ #include #include @@ -609,4 +609,4 @@ class LeastSquaresRegressionGrowStats : public GrowStats { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h index e3d4edbf8a512a027e4b67916d1f2ad3f347a18b..eafad6b591672f67ae816405ff603f9aaba30a1b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ #include #include #include "google/protobuf/any.pb.h" @@ -123,4 +123,4 @@ class TensorDataSet { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h index 0309ec1de9aec1044eb87e01cafc40c26ba3de14..44ec09c50ef3d092bd1bf7f051f492e1fffdd05b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" @@ -89,4 +89,4 @@ class TensorInputTarget : public StoredInputTarget { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h index 946a648f22ff4175782c42cc70c59440e6ac0e17..cc4ec8dc9e330784bbcfeb54fa92e0a2db9449a8 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" @@ -146,4 +146,4 @@ class LeafModelOperatorFactory { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.h b/tensorflow/contrib/tensor_forest/kernels/v4/params.h index 97a9d8d096311faaae774e9e4b2e45f28ed7fa29..b0ed949424756cc498d4b7ad1fb1867fff11b265 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/params.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ #include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" #include "tensorflow/core/platform/types.h" @@ -28,5 +28,4 @@ float ResolveParam(const DepthDependentParam& param, int32 depth); } // namespace tensorforest } // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h index 6c21c0bd3443347bdb0102727b15b26754a0ed53..ad52f89faddb15be77644b5dc374aca73c46b149 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ #include #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" @@ -128,6 +128,4 @@ class AnyCollectionCreator : public CollectionCreator { } // namespace tensorforest } // namespace tensorflow - - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h index 8e002d0414f48a1f409952f56c57b4e37815bca0..e6140065bbf12f2eb92c28e4affb3327f86af5d3 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ #include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" #include "tensorflow/core/platform/types.h" @@ -47,4 +47,4 @@ float WeightedSmoothedGini(float sum, float square, int num_classes); } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h index b6e543b96fd5a00f78555eaf8558f0a95d0a6713..289c81e9d51dbc5d2023f7eabce8c2089748645d 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" @@ -71,4 +71,4 @@ class TestableDataSet : public TensorDataSet { } // namespace tensorforest } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 346c03067d51350b9939123d6afa69d8127bdf01..198da0203a7d17249c4f50110713121b74d5ca4f 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -44,13 +44,22 @@ cc_library( ], ) +cc_library( + name = "version", + hdrs = ["version.h"], + visibility = ["//visibility:public"], +) + tf_cc_binary( name = "capture_tpu_profile", - srcs = ["capture_tpu_profile.cc"], + srcs = [ + "capture_tpu_profile.cc", + ], visibility = ["//visibility:public"], deps = [ ":dump_tpu_profile", ":tpu_profiler_proto_cc", + ":version", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index b67f2f47a7b753fd4629d7ad4db0b4c67933ce0b..1cded9f8cf01b931d1d535a54effd54459dd8e9a 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" #include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" +#include "tensorflow/contrib/tpu/profiler/version.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/init_main.h" @@ -84,6 +85,9 @@ int main(int argc, char** argv) { "Duration of tracing in ms. Default is 2000ms."), }; + std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION + << std::endl; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) { diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index 0ed5b2fad333eaa8a9820da334e953cbc282371f..b842951eb2c22792a22d9a16c022d3122391f4e8 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -171,7 +171,6 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run, DumpToolDataToLogDirectory(profile_run_dir, tool_data, os)); } } - TF_RETURN_IF_ERROR(DumpGraphEvents(logdir, run, response, os)); return Status::OK(); } diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h index 65b92aa41867ed9e2e8b06c9e34dd99068bb459c..25b958bcfeab7e0cfd9c180b8af4057e9bdfc73b 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ +#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ +#define TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ #include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -35,4 +35,4 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run, } // namespace tpu } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ +#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 7970c20a2693cbbe91a136080240f676d29f2053..846db1332991e8c84f51dc7e6bcc3592a955991e 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -42,8 +42,9 @@ def main(unused_argv=None): if not FLAGS.service_addr or not FLAGS.logdir: sys.exit('service_addr and logdir must be provided.') executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE) + logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir)) cmd = [executable_path] - cmd.append('--logdir='+FLAGS.logdir) + cmd.append('--logdir='+logdir) cmd.append('--service_addr='+FLAGS.service_addr) cmd.append('--duration_ms='+str(FLAGS.duration_ms)) subprocess.call(cmd) diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index 179d29602b9f970fb450bc057332fa092066255c..92196638318f4a551619d04ba730ac66a58d596e 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,16 +20,12 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.3.0-a1' +_VERSION = '1.4.3-a2' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', ] -REQUIRED_PACKAGES = [ - 'tensorflow >= 1.2.0', -] - setup( name='cloud_tpu_profiler', version=_VERSION.replace('-', ''), @@ -45,13 +41,12 @@ setup( entry_points={ 'console_scripts': CONSOLE_SCRIPTS, }, - install_requires=REQUIRED_PACKAGES, classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 3 - Alpha', + 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', 'Intended Audience :: Education', diff --git a/tensorflow/contrib/tpu/profiler/trace_events_to_json.h b/tensorflow/contrib/tpu/profiler/trace_events_to_json.h index 992eae43d903db495850ced7a59e38120d3fed34..3bd76dd01c7d0f35bad9386c11811743e1709fca 100644 --- a/tensorflow/contrib/tpu/profiler/trace_events_to_json.h +++ b/tensorflow/contrib/tpu/profiler/trace_events_to_json.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_ +#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_ +#define TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_ #include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" #include "tensorflow/core/platform/types.h" @@ -29,4 +29,4 @@ string TraceEventsToJson(const Trace &trace); } // namespace tpu } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_ +#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_ diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h new file mode 100644 index 0000000000000000000000000000000000000000..0f645a549296b0f05acfb7ae564be1daf37925f8 --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -0,0 +1,21 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ +#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ + +#define TPU_PROFILER_VERSION "1.4.3" + +#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 33e47f674d798f622fb08121dabb67d7f45af15b..a49a3dcf2999053d9b0d5ffcb6411e693702d785 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -21,6 +21,7 @@ from __future__ import print_function import platform +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops if platform.system() != "Windows": @@ -40,6 +41,63 @@ if platform.system() != "Windows": del op # Unused # The gradient of a cross replica sum is also a cross-replica sum. return gen_tpu_ops.cross_replica_sum(grad) + + # This extra type checking exists to give a more helpful error message in + # the common case that uint8 and int64 values are infed. Remove when both + # types are supported. + + _SUPPORTED_INFEED_DTYPES = set([ + dtypes.int32, dtypes.bfloat16, dtypes.float32 + ]) + + def infeed_dequeue(dtype, shape, name=None): + """A placeholder op for a value that will be fed into the computation. + + Args: + dtype: A `tf.DType`. The type of elements in the tensor. + shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `dtype`. + A tensor that will be provided using the infeed mechanism. + + Raises: + TypeError: If 'dtype` is not a supported infeed type. + """ + if dtype not in _SUPPORTED_INFEED_DTYPES: + raise TypeError( + "{} is not a supported TPU infeed type. Supported types are: " + "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) + + return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name) + + # pylint: disable=redefined-outer-name + def infeed_dequeue_tuple(dtypes, shapes, name=None): + """A placeholder op for values fed into the TPU simultaneously as a tuple. + + Args: + dtypes: A list of `tf.DType`s that has length `>= 1`. + The element types of each element in `outputs`. + shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). + The shapes of each tensor in `outputs`. + name: A name for the operation (optional). + + Returns: + A list of `Tensor` objects of type `dtypes`. + A list of tensors that will be provided using the infeed mechanism. + + Raises: + TypeError: If a type in 'dtypes` is not a supported infeed type. + """ + for dtype in dtypes: + if dtype not in _SUPPORTED_INFEED_DTYPES: + raise TypeError( + "{} is not a supported TPU infeed type. Supported types are: " + "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) + return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name) + # pylint: enable=redefined-outer-name + else: # We have already built the appropriate libraries into the binary via CMake # if we have built contrib, so we don't need this diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index bb35f4ece6ea7ebfd0db0332c6e8f2d2e2eb9f81..2ae3a26a853bf4941ac3855ec525293b5a508a2a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # =================================================================== - """TPUEstimator class.""" from __future__ import absolute_import @@ -24,6 +23,7 @@ from contextlib import contextmanager import copy import threading import time +import traceback import six from six.moves import queue as Queue # pylint: disable=redefined-builtin @@ -60,7 +60,6 @@ from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util - _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. _TPU_ESTIMATOR = 'tpu_estimator' @@ -86,28 +85,28 @@ def _create_global_step(graph): initializer=init_ops.zeros_initializer(), trainable=False, use_resource=True, - collections=[ops.GraphKeys.GLOBAL_VARIABLES, - ops.GraphKeys.GLOBAL_STEP]) + collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) def _create_or_get_iterations_per_loop(): graph = ops.get_default_graph() - iter_vars = graph.get_collection(_TPU_ESTIMATOR) + collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) + iter_vars = graph.get_collection(collection_name) if len(iter_vars) == 1: return iter_vars[0] elif len(iter_vars) > 1: raise RuntimeError('Multiple iterations_per_loop_var in collection.') with ops.colocate_with(training_util.get_global_step()): - with variable_scope.variable_scope(_TPU_ESTIMATOR, - reuse=variable_scope.AUTO_REUSE): + with variable_scope.variable_scope( + _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE): return variable_scope.get_variable( _ITERATIONS_PER_LOOP_VAR, initializer=init_ops.zeros_initializer(), shape=[], dtype=dtypes.int32, trainable=False, - collections=[_TPU_ESTIMATOR], + collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], use_resource=True) @@ -241,9 +240,9 @@ class _TPUContext(object): return self._eval_batch_size return None - global_batch_size = (self._train_batch_size if - mode == model_fn_lib.ModeKeys.TRAIN - else self._eval_batch_size) + global_batch_size = ( + self._train_batch_size + if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size) # On TPU if self.is_input_sharded_per_core(): return global_batch_size // self.num_cores @@ -290,8 +289,9 @@ class _TPUContext(object): # The tpu job is determined by the run_config. Right now, this method is # required as tpu_config is not part of the RunConfig. mode = self._assert_mode() - master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL - else run_config.master) + master = ( + run_config.evaluation_master + if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) if master in _LOCAL_MASTERS: return None @@ -318,6 +318,7 @@ class _TPUContext(object): def tpu_host_placement_function(self): """Returns the TPU host place function.""" master = self.master_job + def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name assert _sentinal is None if core_id is not None and host_id is not None: @@ -332,19 +333,23 @@ class _TPUContext(object): if core_id is not None: host_id = core_id / 8 return '/job:%s/task:%d/device:CPU:0' % (master, host_id) + return _placement_function @property def tpu_device_placement_function(self): master = self.master_job job_device = '' if master is None else ('/job:%s' % master) + def _placement_function(i): return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8) + return _placement_function @property def tpu_ordinal_function(self): """Returns the TPU ordinal fn.""" + def _tpu_ordinal_function(index): """Return the TPU ordinal associated with a shard. @@ -357,6 +362,7 @@ class _TPUContext(object): The ordinal of the TPU device the shard's infeed should be placed on. """ return index % 8 + return _tpu_ordinal_function @@ -370,14 +376,16 @@ class _SIGNAL(object): STOP = -2 -class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ - 'mode', - 'predictions', - 'loss', - 'train_op', - 'eval_metrics', - 'export_outputs', - 'scaffold_fn'])): +class TPUEstimatorSpec( + collections.namedtuple('TPUEstimatorSpec', [ + 'mode', + 'predictions', + 'loss', + 'train_op', + 'eval_metrics', + 'export_outputs', + 'scaffold_fn' + ])): """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and @@ -387,7 +395,7 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ `metric_fn` runs on CPU to generate metrics and `tensors` represents the `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. To be precise, TPU evaluation expects a slightly different signature from the - ${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a + @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The `tensors` usually specify the model logits, which are transferred back from @@ -415,111 +423,116 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ """Creates a validated `TPUEstimatorSpec` instance.""" if eval_metrics is not None: _EvalMetrics.validate(eval_metrics) - return super(TPUEstimatorSpec, cls).__new__(cls, - mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metrics=eval_metrics, - export_outputs=export_outputs, - scaffold_fn=scaffold_fn) + return super(TPUEstimatorSpec, cls).__new__( + cls, + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + eval_metrics=eval_metrics, + export_outputs=export_outputs, + scaffold_fn=scaffold_fn) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu( self.eval_metrics) scaffold = self.scaffold_fn() if self.scaffold_fn else None - return model_fn_lib.EstimatorSpec(mode=self.mode, - predictions=self.predictions, - loss=self.loss, - train_op=self.train_op, - eval_metric_ops=eval_metric_ops, - export_outputs=self.export_outputs, - scaffold=scaffold) + return model_fn_lib.EstimatorSpec( + mode=self.mode, + predictions=self.predictions, + loss=self.loss, + train_op=self.train_op, + eval_metric_ops=eval_metric_ops, + export_outputs=self.export_outputs, + scaffold=scaffold) + + +class _OpQueueContext(object): + """Manages work queue and thread for a infeed/outfeed thread.""" + + def __init__(self, name, target, args): + self._name = name + self._queue = Queue.Queue() + args = (self,) + args + self._thread = threading.Thread(name=name, target=target, args=args) + self._thread.daemon = True + self._thread.start() + + def stop(self): + self._queue.put(_SIGNAL.STOP) + + def send_next_batch_signal(self, iterations): + self._queue.put(iterations) + + def read_iteration_counts(self): + while True: + signal = self._queue.get(block=True) + logging.debug('%s read signal %s', self._name, signal) + if signal == _SIGNAL.STOP: + logging.info('%s received signal, stopping.', self._name) + return + yield signal + def join(self): + logging.info('Shutting down %s thread.' % self._name) + self.stop() + self._thread.join() -class _InfeedOutfeedThreadBaseController(object): - """This wraps the infeed/outfeed thread and stops when Estimator finishes.""" - def __init__(self, thd): - self._signal_queue = Queue.Queue() - thd.daemon = True - thd.start() - self._thd = thd +class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): + """A Session hook setting up the TPU initialization, infeed, and outfeed. - def block_and_get_signal(self): - return self._signal_queue.get() + This hook does two major things: + 1. initialize and shutdown TPU system. + 2. launch and join the threads for infeed enqueue and (optional) outfeed + dequeue. + """ - def send_next_batch_signal(self, signal=_SIGNAL.NEXT_BATCH): - self._signal_queue.put(signal) + def __init__(self, ctx, enqueue_ops, dequeue_ops=None): + self._master_job = ctx.master_job + self._enqueue_ops = enqueue_ops + self._dequeue_ops = dequeue_ops + self._initial_infeed_sleep_secs = ( + ctx.config.tpu_config.initial_infeed_sleep_secs) + self._session_cancel_timer = None - def join(self): - self._signal_queue.put(_SIGNAL.STOP) - self._thd.join() + self._feed_error = None + self._finished = False + def begin(self): + logging.info('TPU job name %s', self._master_job) + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_op = [tpu.initialize_system(job=self._master_job)] + self._finalize_op = [tpu.shutdown_system(job=self._master_job)] -class _OutfeedThreadController(_InfeedOutfeedThreadBaseController): - """This wraps the outfeed thread and stops when Estimator finishes.""" + def _log_error(self, session, error): + """Log an infeed or outfeed error. - def __init__(self, session, dequeue_ops): - super(_OutfeedThreadController, self).__init__( - threading.Thread(target=self._execute_dequeue_ops, - args=(session, dequeue_ops))) + This logs a short error message immediately, and schedules a timer to + emit the full stack trace and error message after a short period of time. + If the main session has terminated by the time the timer triggers, we + assume the real source of the error was from the main session and avoid + emitting a stack trace for the infeed. - def _execute_dequeue_ops(self, session, dequeue_ops): - count = 0 - while True: - signal = self.block_and_get_signal() - if signal == _SIGNAL.STOP: - logging.info('Stop outfeed thread.') - return + Args: + session: `tf.Session`, session to be terminated + error: exception that triggered logging. + """ + logging.warning( + '\n\n' + 'Error occurred during infeed/outfeed. This may be due to a compile ' + 'error in the main session. Waiting for a short time for the main ' + 'session to come back.\n\n%s', error) - iterations = signal - for i in range(iterations): - logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) - session.run(dequeue_ops) - count += 1 + self._feed_error = traceback.format_exc() - def join(self): - logging.info('Waiting for Outfeed Thread to exit.') - super(_OutfeedThreadController, self).join() - - -class _InfeedThreadController(_InfeedOutfeedThreadBaseController): - """This wraps the infeed thread and stops when Estimator finishes.""" - - def __init__(self, session, enqueue_ops, initial_infeed_sleep_secs): - super(_InfeedThreadController, self).__init__( - threading.Thread( - target=self._input_thread_fn_for_loading, - args=(session, enqueue_ops, initial_infeed_sleep_secs))) - - def _input_thread_fn_for_loading(self, session, enqueue_ops, - initial_infeed_sleep_secs): - count = 0 - if initial_infeed_sleep_secs: - logging.info('Infeed thread sleeping for %d seconds.', - initial_infeed_sleep_secs) - time.sleep(initial_infeed_sleep_secs) - logging.info('Infeed thread starting after sleep') - try: - while True: - signal = self._signal_queue.get() - if signal == _SIGNAL.STOP: - logging.info('Stop Infeed input thread.') - return - - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - # Enqueue batches for next loop. - session.run(enqueue_ops) - else: - iterations = signal - for i in range(iterations): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(enqueue_ops) - count += 1 + # If we've already encountered a feed error, don't schedule another + # cancellation op. + if self._session_cancel_timer: + return - except Exception: # pylint: disable=broad-except + def _cancel_session(): # Close the session to avoid the main thread from hanging. If input # pipeline triggers any error, the infeed thread dies but the main thread # for TPU computation waits for the infeed enqueue forever. Close the @@ -534,77 +547,94 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController): # exception in the main thread, instead of the expected compile error. # User code that depends on having the proper exception type will # therefore be confused. - logging.error( - 'Failed running infeed, closing session.\n' - 'You may see an exception from your main session after this. ' - 'Sleep for 2 minutes before close Session from infeed thread to ' - 'allow the main thread returning an error first, if any.', - exc_info=1 - ) - time.sleep(120) - logging.error('Closing the failed session.') - session.close() - - def join(self): - logging.info('Waiting for Infeed Thread to exit.') - super(_InfeedThreadController, self).join() - - -class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): - """A Session hook setting up the TPU initialization, infeed, and outfeed. - - This hook does two major things: - 1. initialize and shutdown TPU system. - 2. launch and join the threads for infeed enqueue and (optional) outfeed - dequeue. - """ + time.sleep(5) + + # If the main session is still running, the infeed/outfeed errors are + # legitimate, and should be logged. + if not self._finished: + logging.error('Feed error: %s', self._feed_error) + logging.error('Closing session. A RuntimeError should follow.') + session.close() + + self._session_cancel_timer = threading.Thread(target=_cancel_session) + self._session_cancel_timer.daemon = True + self._session_cancel_timer.start() + + def _run_infeed(self, queue_ctx, session): + logging.info('Starting infeed thread controller.') + if self._initial_infeed_sleep_secs: + logging.info('%s thread sleeping for %d seconds.', self._name, + self._initial_infeed_sleep_secs) + time.sleep(self._initial_infeed_sleep_secs) + logging.info('%s thread starting after sleep', self._name) - def __init__(self, ctx, enqueue_ops, dequeue_ops=None): - self._master_job = ctx.master_job - self._enqueue_ops = enqueue_ops - self._dequeue_ops = dequeue_ops - self._initial_infeed_sleep_secs = ( - ctx.config.tpu_config.initial_infeed_sleep_secs) + try: + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + for _ in queue_ctx.read_iteration_counts(): + session.run(self._enqueue_ops) + else: + for count, steps in enumerate(queue_ctx.read_iteration_counts()): + for i in xrange(steps): + logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + session.run(self._enqueue_ops) + logging.debug('Infeed thread finished, shutting down.') + except Exception as e: # pylint: disable=broad-except + self._log_error(session, e) - def begin(self): - logging.info('TPU job name %s', self._master_job) - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_op = [tpu.initialize_system(job=self._master_job)] - self._finalize_op = [tpu.shutdown_system(job=self._master_job)] + def _run_outfeed(self, queue_ctx, session): + logging.info('Starting outfeed thread controller.') + try: + for count, steps in enumerate(queue_ctx.read_iteration_counts()): + for i in xrange(steps): + logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) + session.run(self._dequeue_ops) + except Exception as e: # pylint: disable=broad-except + self._log_error(session, e) def after_create_session(self, session, coord): logging.info('Init TPU system') - session.run(self._init_op, - options=config_pb2.RunOptions(timeout_in_ms=5*60*1000)) + session.run( + self._init_op, + options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) logging.info('Start infeed thread controller') - self._infeed_thd_controller = _InfeedThreadController( - session, self._enqueue_ops, self._initial_infeed_sleep_secs) + self._infeed_controller = _OpQueueContext( + name='InfeedController', target=self._run_infeed, args=(session,)) if self._dequeue_ops is not None: logging.info('Start outfeed thread controller') - self._outfeed_thd_controller = _OutfeedThreadController( - session, self._dequeue_ops) + self._outfeed_controller = _OpQueueContext( + name='OutfeedController', target=self._run_outfeed, args=(session,)) def before_run(self, run_context): + if self._feed_error: + logging.warning('Feed error occurred, terminating session.') + run_context.request_stop() + return + iterations = run_context.session.run(self._iterations_per_loop_var) logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) + self._infeed_controller.send_next_batch_signal(iterations) - self._infeed_thd_controller.send_next_batch_signal(iterations) if self._dequeue_ops is not None: # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. - logging.info( - 'Dequeue next (%d) batch(es) of data from outfeed.', iterations) - self._outfeed_thd_controller.send_next_batch_signal(iterations) + logging.info('Dequeue next (%d) batch(es) of data from outfeed.', + iterations) + self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): + if self._session_cancel_timer: + logging.warning('Feed error occurred; waiting for message.') + self._session_cancel_timer.join() + + self._finished = True logging.info('Stop infeed thread controller') - self._infeed_thd_controller.join() + self._infeed_controller.join() if self._dequeue_ops is not None: logging.info('Stop output thread controller') - self._outfeed_thd_controller.join() + self._outfeed_controller.join() logging.info('Shutdown TPU system.') session.run(self._finalize_op) @@ -675,8 +705,8 @@ class _TPUStopAtStepHook(session_run_hook.SessionRunHook): run_context.request_stop() else: iterations = self._next_iterations(global_step, self._last_step) - self._iterations_per_loop_var.load(iterations, - session=run_context.session) + self._iterations_per_loop_var.load( + iterations, session=run_context.session) class _SetEvalIterationsHook(session_run_hook.SessionRunHook): @@ -697,8 +727,8 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook): self._iterations_per_loop_var.load(self._num_steps, session=session) -def generate_per_core_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder): +def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, + inputs_structure_recorder): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() @@ -728,9 +758,9 @@ def generate_per_core_enqueue_ops_fn_for_host( per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, - tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) return per_host_enqueue_ops + return enqueue_ops_fn, captured_infeed_queue @@ -747,8 +777,7 @@ def generate_per_host_enqueue_ops_fn_for_host( features, labels = inputs else: features, labels = inputs, None - inputs_structure_recorder.validate_and_record_structure( - features, labels) + inputs_structure_recorder.validate_and_record_structure(features, labels) unsharded_tensor_list = ( inputs_structure_recorder.flatten_features_and_labels( features, labels)) @@ -762,9 +791,9 @@ def generate_per_host_enqueue_ops_fn_for_host( per_host_enqueue_ops = ( infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_tensor_list, - placement_function=lambda x: device)) + unsharded_tensor_list, placement_function=lambda x: device)) return per_host_enqueue_ops + return enqueue_ops_fn, captured_infeed_queue @@ -814,6 +843,7 @@ class _InputPipeline(object): def validate_and_record_structure(self, features, labels): """Validates and records the structure of features` and `labels`.""" + def _extract_key_names(tensor_or_dict): if tensor_or_dict is None: return [] @@ -841,8 +871,8 @@ class _InputPipeline(object): flattened_inputs = [] if self._feature_names: # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([features[name] - for name in self._feature_names]) + flattened_inputs.extend( + [features[name] for name in self._feature_names]) else: flattened_inputs.append(features) @@ -869,11 +899,11 @@ class _InputPipeline(object): ValueError: If the number of expected tensors from `flattened_inputs` mismatches the recorded structure. """ - expected_num_features = (len(self._feature_names) if self._feature_names - else 1) + expected_num_features = ( + len(self._feature_names) if self._feature_names else 1) if self._has_labels: - expected_num_labels = (len(self._label_names) if self._label_names - else 1) + expected_num_labels = ( + len(self._label_names) if self._label_names else 1) else: expected_num_labels = 0 @@ -894,8 +924,8 @@ class _InputPipeline(object): if expected_num_labels == 0: unflattened_label = None elif self._label_names: - unflattened_label = dict(zip(self._label_names, - flattened_inputs[expected_num_features:])) + unflattened_label = dict( + zip(self._label_names, flattened_inputs[expected_num_features:])) else: # Single tensor case. unflattened_label = flattened_inputs[expected_num_features] @@ -960,8 +990,9 @@ class _InputPipeline(object): self._ctx, self._input_fn, self._inputs_structure_recorder)) if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - enqueue_ops.append(_wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) + enqueue_ops.append( + _wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) # Infeed_queue_getter must be called after enqueue_ops_fn is called. @@ -978,8 +1009,9 @@ class _InputPipeline(object): self._batch_axis, host_device)) if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - enqueue_ops.append(_wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) + enqueue_ops.append( + _wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) infeed_queues.append(captured_infeed_queue.get()) @@ -1065,6 +1097,7 @@ class _ModelFnWrapper(object): with ops.control_dependencies([train_op]): return array_ops.identity(loss) + return train_step, captured_scaffold_fn def convert_to_single_tpu_eval_step(self, dequeue_fn): @@ -1113,6 +1146,7 @@ class _ModelFnWrapper(object): with ops.control_dependencies([outfeed_ops]): return math_ops.add(total_loss, loss) + return eval_step, eval_metrics, captured_scaffold_fn def _call_model_fn(self, features, labels): @@ -1137,10 +1171,9 @@ class _ModelFnWrapper(object): kwargs['params'] = params if 'params' not in model_fn_args: - raise ValueError( - 'model_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params[\'batch_size\']'.format(self._model_fn)) + raise ValueError('model_fn ({}) does not include params argument, ' + 'required by TPUEstimator to pass batch size as ' + 'params[\'batch_size\']'.format(self._model_fn)) batch_size_for_model_fn = self._ctx.batch_size_for_model_fn if batch_size_for_model_fn is not None: @@ -1347,8 +1380,9 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): def _log_and_record(self, elapsed_steps, elapsed_time, global_step): examples_per_sec = self._batch_size * elapsed_steps / elapsed_time if self._summary_writer is not None: - example_summary = Summary(value=[Summary.Value( - tag='examples_sec', simple_value=examples_per_sec)]) + example_summary = Summary(value=[ + Summary.Value(tag='examples_sec', simple_value=examples_per_sec) + ]) self._summary_writer.add_summary(example_summary, global_step) logging.info('examples/sec: %g', examples_per_sec) @@ -1487,9 +1521,8 @@ class TPUEstimator(estimator_lib.Estimator): '`config` must be provided with type `tpu_config.RunConfig`') if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): - raise ValueError( - '{} are reserved keys but existed in params {}.'.format( - _RESERVED_PARAMS_KEYS, params)) + raise ValueError('{} are reserved keys but existed in params {}.'.format( + _RESERVED_PARAMS_KEYS, params)) if use_tpu: if train_batch_size is None: @@ -1570,8 +1603,9 @@ class TPUEstimator(estimator_lib.Estimator): if max_steps is not None: util_lib.check_positive_integer(max_steps, 'Train max_steps') - return [_TPUStopAtStepHook(self._iterations_per_training_loop, steps, - max_steps)] + return [ + _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps) + ] def _convert_eval_steps_to_hooks(self, steps): with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: @@ -1639,6 +1673,7 @@ class TPUEstimator(estimator_lib.Estimator): # `features` in `model_fn` signature. def _input_fn(): return input_fn(**kwargs) + return _input_fn def _augment_model_fn(self, model_fn, batch_axis): @@ -1694,9 +1729,10 @@ class TPUEstimator(estimator_lib.Estimator): total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) iterations_per_loop_var = _create_or_get_iterations_per_loop() - mean_loss = math_ops.div( - total_loss, - math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) + mean_loss = math_ops.div(total_loss, + math_ops.cast( + iterations_per_loop_var, + dtype=total_loss.dtype)) # Creates a dummy metric update_op for all metrics. Estimator expects # all metrics in eval_metric_ops have update_op and calls them one by @@ -1724,6 +1760,7 @@ class TPUEstimator(estimator_lib.Estimator): evaluation_hooks=hooks, eval_metric_ops=eval_metric_ops, scaffold=scaffold) + return _model_fn @@ -1736,15 +1773,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) def multi_tpu_eval_steps_on_single_shard(): - return training_loop.repeat(iterations_per_loop_var, - single_tpu_eval_step, - [_ZERO_LOSS], - name='loop') + return training_loop.repeat( + iterations_per_loop_var, + single_tpu_eval_step, [_ZERO_LOSS], + name='loop') - (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard, - inputs=[], - num_shards=num_cores, - outputs_from_all_shards=False) + (loss,) = tpu.shard( + multi_tpu_eval_steps_on_single_shard, + inputs=[], + num_shards=num_cores, + outputs_from_all_shards=False) scaffold = _get_scaffold(captured_scaffold_fn) return loss, eval_metric_ops, scaffold @@ -1761,14 +1799,14 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): def multi_tpu_train_steps_on_single_shard(): return training_loop.repeat( iterations_per_loop_var, - single_tpu_train_step, - [_INITIAL_LOSS], + single_tpu_train_step, [_INITIAL_LOSS], name=b'loop') - (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard, - inputs=[], - num_shards=num_cores, - outputs_from_all_shards=False) + (loss,) = tpu.shard( + multi_tpu_train_steps_on_single_shard, + inputs=[], + num_shards=num_cores, + outputs_from_all_shards=False) scaffold = _get_scaffold(captured_scaffold_fn) return loss, scaffold @@ -1776,6 +1814,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): def _wrap_computation_in_while_loop(device, op_fn): """Wraps the ops generated by `op_fn` in tf.while_loop.""" + def computation(i): with ops.control_dependencies(op_fn()): return i + 1 @@ -1787,7 +1826,8 @@ def _wrap_computation_in_while_loop(device, op_fn): iterations = array_ops.identity(iterations_per_loop_var) return control_flow_ops.while_loop( lambda i: i < iterations, - computation, [constant_op.constant(0)], parallel_iterations=1) + computation, [constant_op.constant(0)], + parallel_iterations=1) def _validate_tpu_training_graph(): @@ -1800,8 +1840,9 @@ def _validate_tpu_training_graph(): # Check if there is atleast one CrossReplicaSum operation in the graph # This should be introduced by using the CrossShardOptimizer wrapper - cross_replica_sum_ops = [o for o in operations - if o.type == _CROSS_REPLICA_SUM_OP] + cross_replica_sum_ops = [ + o for o in operations if o.type == _CROSS_REPLICA_SUM_OP + ] if not cross_replica_sum_ops: raise ValueError( 'CrossShardOptimizer must be used for model training on TPUs.') @@ -1848,9 +1889,11 @@ def _get_scaffold(captured_scaffold_fn): if scaffold: wrapped_finalize = scaffold.finalize + def _finalize(): with _CapturingContext('Inside Scaffold.finalize'): wrapped_finalize() + scaffold.finalize = _finalize return scaffold @@ -1865,9 +1908,8 @@ class _CapturingContext(control_flow_ops.ControlFlowContext): def AddOp(self, op): # pylint: disable=invalid-name for c in op.inputs: if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access - raise ValueError( - '{}: Op {} depends on TPU computation {}, ' - 'which is not allowed.'.format(self._message, op, c)) + raise ValueError('{}: Op {} depends on TPU computation {}, ' + 'which is not allowed.'.format(self._message, op, c)) def __enter__(self): # pylint: disable=protected-access diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md new file mode 100644 index 0000000000000000000000000000000000000000..ca1255b16b1575d291df51dfde696b36c38359ae --- /dev/null +++ b/tensorflow/contrib/tpu/tpu_estimator.md @@ -0,0 +1,241 @@ +# Using the Estimator API with TPUs + + +This document describes how to train a TensorFlow model on TPUs using the +Estimator API. If you are interested in the hardware itself, check out the +[Cloud TPU documentation](https://cloud.google.com/tpu/docs). + +The TPU Estimator simplifies running models on a Cloud TPU by automatically +handling numerous low-level hardware-specific details + +[TOC] + +## Introduction to Estimator + +[TensorFlow +tutorials](https://www.tensorflow.org/extend/estimators) cover the Estimator +API. At a high-level, the Estimator API provides: + +* `Estimator.train()` - train a model on a given input for a fixed number of + steps. +* `Estimator.evaluate()` - evaluate the model on a test set. +* `Estimator.predict()` - run inference using the trained model. +* `Estimator.export_savedmodel()` - export your model for serving. + +In addition, `Estimator` includes default behavior common to training jobs, +such as saving and restoring checkpoints, creating summaries for TensorBoard, +etc. + +`Estimator` requires you to write a `model_fn` and an `input_fn`, which +correspond to the model and input portions of your TensorFlow graph. + +The following code demonstrates using `TPUEstimator` with MNIST example to +handle training: + + def model_fn(features, labels, mode, params): + """A simple CNN.""" + del params # unused + + input_layer = tf.reshape(features, [-1, 28, 28, 1]) + conv1 = tf.layers.conv2d( + inputs=input_layer, filters=32, kernel_size=[5, 5], padding="same", + activation=tf.nn.relu) + pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) + conv2 = tf.layers.conv2d( + inputs=pool1, filters=64, kernel_size=[5, 5], + padding="same", activation=tf.nn.relu) + pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) + pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64]) + dense = tf.layers.dense(inputs=pool2_flat, units=128, activation=tf.nn.relu) + dropout = tf.layers.dropout( + inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN) + logits = tf.layers.dense(inputs=dropout, units=10) + onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) + + loss = tf.losses.softmax_cross_entropy( + onehot_labels=onehot_labels, logits=logits) + + learning_rate = tf.train.exponential_decay( + FLAGS.learning_rate, tf.train.get_global_step(), 100000, 0.96) + + optimizer = tpu_optimizer.CrossShardOptimizer( + tf.train.GradientDescentOptimizer(learning_rate=learning_rate)) + + train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) + return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op) + + + def get_input_fn(filename): + """Returns an `input_fn` for train and eval.""" + + def input_fn(params): + """An input_fn to parse 28x28 images from filename using tf.data.""" + batch_size = params["batch_size"] + + def parser(serialized_example): + """Parses a single tf.Example into image and label tensors.""" + features = tf.parse_single_example( + serialized_example, + features={ + "image_raw": tf.FixedLenFeature([], tf.string), + "label": tf.FixedLenFeature([], tf.int64), + }) + image = tf.decode_raw(features["image_raw"], tf.uint8) + image.set_shape([28 * 28]) + # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5] + image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 + label = tf.cast(features["label"], tf.int32) + return image, label + + dataset = tf.contrib.data.TFRecordDataset( + filename, buffer_size=FLAGS.dataset_reader_buffer_size) + dataset = dataset.map(parser).cache().repeat().batch(batch_size) + images, labels = dataset.make_one_shot_iterator().get_next() + # set_shape to give inputs statically known shapes. + images.set_shape([batch_size, 28 * 28]) + labels.set_shape([batch_size]) + return images, labels + return input_fn + + + def main(unused_argv): + + tf.logging.set_verbosity(tf.logging.INFO) + + run_config = tpu_config.RunConfig( + master=FLAGS.master, + model_dir=FLAGS.model_dir, + session_config=tf.ConfigProto( + allow_soft_placement=True, log_device_placement=True), + tpu_config=tpu_config.TPUConfig(FLAGS.iterations, FLAGS.num_shards),) + + estimator = tpu_estimator.TPUEstimator( + model_fn=model_fn, + use_tpu=FLAGS.use_tpu, + train_batch_size=FLAGS.batch_size, + eval_batch_size=FLAGS.batch_size, + config=run_config) + + estimator.train(input_fn=get_input_fn(FLAGS.train_file), + max_steps=FLAGS.train_steps) + + +Although this code is quite simple by appearance, there are some new +concepts to learn for using `TPU`s. The next section will cover the most +important details. + +## New Concepts Related to TPU/TPUEstimator + +TF programs run with `TPU Estimator` use an [in-graph +replication](https://www.tensorflow.org/deploy/distributed) approach. + +In-graph replication (also known as single-session replication) differs from +the between-graph replication (also known as multi-session replication) +training typically used in distributed TensorFlow. The major +differences include: + +1. The TensorFlow Session master is not local anymore. The user python program + creates one single graph that is replicated across all the cores in the Cloud + TPU. The typical configuration today sets the TensorFlow session master to be + the first worker. + +1. The input pipeline is placed on remote hosts (instead of local) to ensure the + training examples can be fed as fast as possible to TPU system. All queue-based + input pipelines do not work effectively. Dataset (tf.data) is + required. + +1. Workers in the TPU system operate in synchronous fashion, and each perform + the same step at the same time. + +Regarding programming model, _"The programmer picks a (large) batch size B and +writes the program (and sets hyperparameters) based on that batch size. The +system distributes the computation across the available devices." + +To align these, `TPUEstimator` wraps the computation (the `model_fn`) and +distributes it to all available TPU chips. + +To summarize: + +- The `input_fn` models the input pipeline running on remote host CPU. Use + `tf.data` to program the input Ops. `input_fn` is expected to be invoked + multiple times when using TPU pods. Each handles one device's input of the + global batch. The shard batch size should be retrieved from + `params['batch_size']`. We plan to provide better abstraction about the + sharding mechanism for `tf.data` to remove the `params['batch_size']`. + +- The `model_fn` models the computation which will be replicated and distributed + to all TPU chips. It should only contains ops that are supported by TPUs. + +## Convert from Vanilla Estimator to TPUEstimator + +It is always recommended to port a small, simple model first to make sure that +you are familiar with the basic concepts of `TPUEstimator` and test end-to-end +behavior. Once your simple model runs, gradually add more functionality. +In addition, there are several sample models, available at +[github.com/tensorflow/tpu-demos](https://github.com/tensorflow/tpu-demos). + +To convert your code from the vanilla `Estimator` class to use TPUs, change the +following (note some of the details may change over time): + +- Switch from `tf.estimator.RunConfig` to `tf.contrib.tpu.RunConfig`. +- Set the `TPUConfig` (part of the `tf.contrib.tpu.RunConfig`) to specify the + `iterations_per_loop`, number of iterations to run on the TPU device for one + `session.run` call (per training loop), and `num_shards`, the number of shards + (typically the number of TPU cores you’re running on). TPUs run a number of + iterations of the training loop before returning to host. Until all iterations + on the TPU device are run, no checkpoints or summaries will be saved. In the + future, we’ll choose a reasonable default. +- In `model_fn`, use `tf.contrib.tpu.CrossShardOptimizer` to wrap your + optimizer. Example: + + optimizer = tpu_optimizer.CrossShardOptimizer( + tf.train.GradientDescentOptimizer(learning_rate=learning_rate)) + +- Switch from `tf.estimator.Estimator` to `tf.contrib.tpu.TPUEstimator`. + +The default `RunConfig` will save summaries for TensorBoard every 100 steps and +write checkpoints every 10 minutes. + + +## FAQ + +### Why `tf.data` is Required for the Input Pipeline + +There are two reasons: + +1. The user code runs on the client, while the TPU computation is executed on + the `worker`. Input pipeline ops must be placed on the remote worker for + good performance. Only `tf.data` (Dataset) supports this. + +1. In order to amortize the TPU launch cost, the model train step is wrapped in + a `tf.while_loop`, such that one `Session.run` actually runs many iterations + for one train loop. To remove network back and forth, the input pipeline + in the future will be wrapped in a `tf.while_loop` and be placed on the + corresponding `worker`. Withou this, unnecessary network latency becomes + the performance bottleneck for models with short training-step times, or in + environments where network latency is higher. Only `tf.data` can be wrapped + by a `tf.while_loop`. + + +### How to add other CPU Ops into Graph +As `model_fn` only allows TPU Ops for computation, the easier workaround to add +CPU Ops into Graph is: + +1. Create a [SessionRunHook](https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook). +1. Modify the graph in the `def begin(self)`, +1. Pass the hook to `TPUEstimator.train`. + +### Running On GCP Cloud TPUs +To run your models on GCP Cloud TPUs refer to the [Cloud Documentation](https://cloud.google.com/tpu/docs/tutorials/mnist). +Refer to this link for all [Cloud TPU documentation](https://cloud.google.com/tpu/docs). + + +### Profiling +You can profile the `worker` by using instructions as spcified in the [Cloud TPU Tools](https://cloud.google.com/tpu/docs/cloud-tpu-tools). + + +### Is `int64` supported? +`int64` is not supported by TPU. Cast to int32 if applicable. The only exception +is global step, which relies on `assign_add`. `int64` support for global step +is added to ensure checkpoint compatibility between `TPUEstimator` and non-TPU +`Estimator`. diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 80de0f6eb7e36a1c86f7d44e4053a9757b09f0ae..fdfd27d6a414933b0bec824bae512c45dac24d3c 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -40,7 +40,7 @@ PARAM_RE = re.compile(r""" ((?P[^,\[]*) # single value: "a" or None | \[(?P[^\]]*)\]) # list of values: None or "1,2,3" - ($|,)""", re.VERBOSE) + ($|,\s*)""", re.VERBOSE) def _parse_fail(name, var_type, value, values): diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 28e4b4d01eda9bef07ff7929f74894e09a3e987c..16397622edd382bc6dcb12870de5fa22130a2c2b 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -55,7 +55,7 @@ class HParamsTest(test.TestCase): self.assertEqual(12, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) - hparams.parse('c_c=relu4,b=-2.0e10') + hparams.parse('c_c=relu4, b=-2.0e10') self.assertDictEqual({ 'aaa': 12, 'b': -2.0e10, diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h index 6518e7a10f587a687f2ee3258f5399d74d87364e..61fc6f36f7e5211e43c279506faf09624086d167 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ +#ifndef TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ +#define TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ #include @@ -31,4 +31,4 @@ Status ConvertConstantsToImmutable(const string& in_graph_filename, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ +#endif // TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 38a84ffb10e594568a18dbd06debf32545cb2229..80a5d07ea43531ed2532443b6ff9327b9ece6df7 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -99,7 +99,7 @@ cc_library( alwayslink = 1, ) -tf_cuda_library( +cc_library( name = "rdma_rendezvous_mgr", srcs = ["rdma_rendezvous_mgr.cc"], hdrs = ["rdma_rendezvous_mgr.h"], @@ -114,7 +114,7 @@ tf_cuda_library( ], ) -cc_library( +tf_cuda_library( name = "rdma_mgr", srcs = ["rdma_mgr.cc"], hdrs = ["rdma_mgr.h"], @@ -141,6 +141,8 @@ tf_cuda_library( "//conditions:default": [], }), deps = [ + ":grpc_verbs_client", + ":verbs_service_proto_cc", ":verbs_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md index 7c1c8ea45912be8c471efbe42f43e083639e91fc..1b99f4ce4f645d0c59b2552cf26f47495cbbba73 100644 --- a/tensorflow/contrib/verbs/README.md +++ b/tensorflow/contrib/verbs/README.md @@ -24,66 +24,144 @@ The design is based on TensorFlow r1.0. An RDMA path is added between servers fo During the server setup, an RDMA manager is created to manage low-level RDMA components such as RDMA channel and RDMA adapter, an RDMA rendezvous manager is created to oversee send/recv operations between servers. Following the distributed TensorFlow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer. -TensorFlow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for RDMA operations where pinned memory is required. Two remedies are possible, either the memory is pinned, transfer, then unpinned for each and every tensor to be transferred, or a buffer is pre-allocated and pinned for each tensor. The former incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. The latter incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. The second approach is adopted in this design. Each RDMA channel, representing a RDMA connection to a peer, contains a table of pinned buffers for all the seen tensors that requires transfer. It is assumed that the tensor size rarely changes across different steps. So only one buffer is created for the same tensor across all the steps. In the rare case when the tensor size does increases, the old buffer is discarded and new buffer of larger size is created and pinned. +TensorFlow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for RDMA operations where pinned memory is required. Few remedies are possible: +1. The memory is pinned, transfered, then unpinned for each and every tensor to be transferred. This incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. +2. Buffer is pre-allocated and pinned for each tensor. This incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. +3. Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), there is a smart way to benefit from the TensorFlow allocation theme which is mostly pool based, i.e allocators pre-allocate a large memory block, and allocate the tensors from there. By attaching a custom Visitor to relevant alloactors, we can do a single registration of the entire memory block, which zeros the registration overhead. Once the block is registered, each new tensor allocated will be at a registred address, which will allow us to do direct RDMA writes to it. -When a tensor is prepared for transfer, it is first converted to TensorProto, then the proto is serialized to byte array and copied to the pinned buffer. The content of the buffer is transferred to the remote node via RDMA write. On the remote side, the process is reversed. This is illustrated in the diagram below. The conversion of TensorProto is introduced to simplify transfer of string-tensors. Also since the TensorProto lives in host memory, even if the origin tensor lives in the device, the pinned buffers are all allocated in the host memory. -![TensorFlow RDMA path](./design_diagram.png) +For best performance, we will adopt HKUST 0 copies approach in our solution. This means: + +1. Tensor writes will be done directly from the source tensor to the **result** tensor, with no memory copies in between. This should be done for all DMAable tensors which are located either on CPU or on a RDMA compatible GPU device (GPU direct). +2. Non DMAable tensors (CanMemCopy == false) will be serialized to a TensorProto on the sender side, RDMA written to a registered buffer on the receiver side, and then deserialized by the receiver. +3. Tensors which are located on a non-RDMA-compatible GPU, will be RDMA written to a registered CPU **proxy** buffer on the receiver side, and then copied to GPU by the receiver. -The following improvements can be made in the future. First, conversion to TensorProto and serialization can be avoided for numeric (float/int) tensors since their internal buffer can be access directly as byte array. Second, the pinned buffer may be allocated on device if the tensor is located in the device. This avoids extra device-to-host copy at the expense of extra device memory consumption. ## Design details -### RDMA components +### Terminology -* **RDMA adapter:** The base for RDMA communications. It may contain multiple channels and buffers. It is responsible for handling various incoming RDMA messages. -* **RDMA channel:** Responsible for RDMA connection to a particular node. It manages multiple buffers. A channel has a callback table which stores all the callbacks for the requested tensors. -* **RDMA buffer:** Responsible for sending or receiving data. It has a fixed size memory to store the data. It has a queue to store the pending jobs. There are three types of buffers, message buffer, ACK buffer and tensor buffer. A channel has two message buffers, two ack buffers and many tensor buffers. -* **RDMA manager:** Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc. -* **RDMA rendezvous manager:** manages multiple rdma rendezvous. -* **RDMA rendezvous:** a derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up. +* **Sender** - The node which sends the tensor. +* **Receiver** - The node which receives the tensor. +* **Result tensor** - The destination tensor, allocated on its appropriate device. +* **Proxy tensor** - A CPU allocated tensor, which will be used in the case where the result tensor cannot be RDMA written to directly (GPU direct is disabled or not available). The RDMA write will therefore be done to the proxy tensor, and afterwards we will do a manual local copy from it to the result tensor. -### The SEND operation +### Messages -In TensorFlow, when rendezvous sends a tensor, it merely puts a tensor in a local table in the corresponding rendezvous. If the tensor has been requested, a callback exists in the table. "send" will activate the callback, which tries to send the tensor across the node. +* RDMA_MESSAGE_TENSOR_REQUEST +* RDMA_MESSAGE_META_DATA_RESPONSE +* RDMA_MESSAGE_TENSOR_RE_REQUEST +### Transport protocol -### The RECV operation +The tensor transfer process is initiated when the receiver requests a tensor. In code it is done by calling **Rendezvous::Recv()** or **Rendezvous::RecvAsync()**. The TensorFlow base implementation handles the case where the requested tensor is located on the same node. The more interesting case where the requested tensor is located on a remote node (receiver != sender) is to be handled in a derivation of the pure virtual **BaseRemoteRendezvous::RecvFromRemoteAsync()**. TensorFlow provides a default GRPC based implementation which comes in the vanilla version but suffers in scalability when running large models. Our RDMA based implementation presumes to be more scalable. HKUST's contrib GDR implementation is more scalable than GRPC, and less scalable than ours, only because we did our evolution based on it. -When a tensor is requested, rendezvous' recv function is called. The function first places a callback in the channel's callback table, which will be activated once the tensor is sent from the source. In the next step, a message is sent to notify the source of the requested tensor. Once the source receives the message, it will check locally for the tensor, if not found, a callback is placed in the table, otherwise, the tensor id will be placed at corresponding RDMA buffer's job queue for future transmission. When a tensor is scheduled to be transmitted, the RDMA buffer needs to have the memory allocated and initialized (registered with the remote buffer info). If the memory is not ready, the transmission is deferred, a message is sent to the destination to establish the memory first. The other case a transmission can be deferred is when the buffer is still being used by an on-going transmission. +Our entry point is the implementation of **RdmaRemoteRendezvous::RecvFromRemoteAsync()**, located in rdma_rendezvous_mgr.cc. The implementation creates a new **RdmaTensorRequest** object, keyed by request index (uint32_t), stores it in a list of pending requests, and calls its **Start()** method. The **Start()** method basically does 2 things: -### Three types of RDMA buffers +1. Allocate the result tensor (and the proxy tensor if required). +2. Send a **RDMA_MESSAGE_TENSOR_REQUEST** to the sender, containing the address of the destination tensor (result/proxy) for RDMA write. -* **Message buffer:** responsible for sending message only. -* **Ack buffer:** once a message is sent, the recipient needs to send an ack via the ack buffer to free up the message buffer. An ack buffer is exclusively for its coupled message buffer. -* **Tensor buffer:** responsible for sending tensors. The recipient needs to send back a message to free up the sending buffer. +In order to allocate the result and proxy tensors, we need to know the tensor's meta-data, i.e. shape and data-type for DMAable tensors, and proto-size for serialized tensors. Unfortunately, this information is only available on the sender side which complicates manners. In order to avoid sending extra messages for querying the meta-data at each step, we store a local meta-data cache per tensor, which will only be update upon changes. Based on the assumption that the meta-data of a tensor rarely changes between steps, we expect that on most times the cache will only be updated once. The sender is responsible to detect changes in the meta-data, and update the receiver. In order for the sender to know that the meta-data had changed, each **RDMA_MESSAGE_TENSOR_REQUEST** will contain the meta-data that the receiver had grabbed from the local cache. The sender will then compare the meta-data from the message to the tensor's new meta-data. -### RDMA packet format +When the sender receives an **RDMA_MESSAGE_TENSOR_REQUEST**, it will create a new **RdmaTensorResponse** object for the given request message, store it in a list of pending responses, and will invoke its **Start()** method. The **Start()** method does the following: -|type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|data_type|tensor_shape|tensor_bytes|tensor_buffer| +1. Grab the source tensor from the local table (In code, **RecvLocalAsync()**). +2. If the source tensor is not DMAable, serialize it to a TensorProto. +3. If the source tensor is located on a device which cannot be DMA written from, copy it to CPU. +4. If it is the first time this tensor is requested, or if the tensor's meta-data changed: + 1. Clone the tensor's data to be sent later. + 2. Send a **RDMA_MESSAGE_META_DATA_RESPONSE** containing the new meta-data. +5. Otherwise: + 1. RDMA write the tensor (or TensorProto) to the destination address and rkey specified in the request message. The immediate value for the write will be the request index. -### Six types of RDMA messages -* RDMA_MESSAGE_ACK -* RDMA_MESSAGE_BUFFER_IDLE -* RDMA_MESSAGE_BUFFER_REQUEST -* RDMA_MESSAGE_BUFFER_RESPONSE -* RDMA_MESSAGE_TENSOR_REQUEST -* RDMA_MESSAGE_TENSOR_WRITE - -### Actions upon receiving RDMA messages -* RDMA_MESSAGE_ACK - * sender: mark local ack buffer idle. - * receiver: mark remote message buffer idle, send next item. -* RDMA_MESSAGE_BUFFER_IDLE - * sender: mark local message buffer idle, send next item. - * receiver: send ack, set remote tensor buffer idle, send next item. -* RDMA_MESSAGE_BUFFER_REQUEST - * sender: mark local message buffer idle, send next item. - * receiver: send ack, find or create tensor buffer, send BUFFER_RESPONSE. -* RDMA_MESSAGE_BUFFER_RESPONSE - * sender: mark local message buffer idle, send next item. - * receiver: send ack, set remote buffer info, set local and remote buffer idle, send next item. -* RDMA_MESSAGE_TENSOR_REQUEST - * sender: mark local message buffer idle, send next item. - * receiver: send ack, find or create tensor buffer, enqueue tensor id, send next item. -* RDMA_MESSAGE_TENSOR_WRITE - * sender: mark local message buffer idle, send next item. - * receiver: run callback. + +When the receiver receives the **RDMA_MESSAGE_META_DATA_RESPONSE**, it will locate the relevant **RdmaTensorRequest** using the request index specified in the message, and invoke its **RecvTensorMetaData()** which does the following: + +1. Update the local meta-data cache. +2. Reallocate the result/proxy tensors. +3. Re-send the tensor request. For tracability, the new message has a different name: **RDMA_MESSAGE_TENSOR_RE_REQUEST**. + +When the sender receives a **RDMA_MESSAGE_TENSOR_RE_REQUEST**, it will locate the relevant **RdmaTensorResponse** using the request index specified in the message, and invoke its **Resume()** method, which will RDMA write the contents of the tensor that was cloned earlier, to the new remote address specified in the re-request. + +When the receiver receives the RDMA write, it will locate the relevant **RdmaTensorRequest** using the request index which is the immediate value. It will then invoke its **RecvTensorContent()** which does the following: + +1. Proxy copy/deserialize if required. +2. Invoke the done callback. +3. Deallocate the result/proxy tensors and remove the request from the pending list. + +![alt text](verbs_with_0_copies.png "Transport protocol") + +### Additional design notes + +1. When the sender receives a tensor request, the source tensor may or may not be ready yet. The situation is handled through a process of tag matching: + * If the request arrives before the tensor is ready, then a callback is put in a local table, and will be invoked once the tensor arrives. + * If the tensor is ready before the request arives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediatly. + In code it is done by calling **RecvLocalAsync()**, which receives the tensor's key, step-id, and the callback. +2. When the callback is invoked, the relevant tensor is removed from the tag matching table. In the case where we need to send the tensor's meta-data, the **RdmaTensorResponse** will store a copy of the tensor until the re-request arrives. +3. The sending of protocol messages (**RDMA_MESSAGE_TENSOR_REQUEST**, **RDMA_MESSAGE_META_DATA_RESPONSE** and **RDMA_MESSAGE_TENSOR_RE_REQUEST**) is done by the class **RdmaMessageBuffer**. All messages are sent using RDMA writes from/to fixed messages buffers. This implies that we cannot send on a specific channel more than one message at a time. In order to synchronize the messages, the **RdmaMessageBuffer** holds the a local and remote buffer statuses which can be either busy or idle. When a write is issued, both statuses will be changed to busy. When the write-complete event is received, the local status is changed to idle. When the write is received on the remote side, the remote side will parse the message, and return an ACK back to the sending side on which the sending side will update the remote status to idle. When both the local and remote statuses are idle, the next message can be sent. +5. ACK writes are empty writes (hence they require no buffer) with immediate value 0xFFFFFFFE. Message writes have the immediate value 0xFFFFFFFF. All other writes are tensor-content writes whose immediate value is the request-index. + +### RDMA components + +* **enum RdmaImmDataType** - Immediate types to distinguish between different RDMA writes on the remote side. Ack writes and control-message writes have a fixed immediate value. The rest of the writes are tensor writes and the immediate value is the relevant request index. +* **enum RdmaWriteIDType** - Types to distinguish between different RDMA write-complete events: Ack, control message and tensor writes. +* **class RdmaWriteID** - Context for RDMA write complete events. Holds the RdmaWriteIDType and additional data. +* **class RdmaTensorMetaData** - Meta-data for a tensor (type, shape, is_dead, proto_size). +* **class RdmaMemoryMgr** - Manages the meta-data cache, and the registered memory regions. +* **class RdmaTensorRequest** - Holds and manages information for a single tensor request throughout the entire receive cycle. API: + * **Start()** - Start the request sequence. + * Allocate the result tensor (and proxy tensor if required). + * Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side. + * **RecvTensorMetaData()** - Receive meta-data from the remote side. + * Update the local meta-data cache. + * Reallocate the result tensor (and proxy tensor if required). + * Re-send the request to the remote side. + * **RecvTensorContent()** - Receive tensor content from the remote side (RDMA write was completed). + * Decode proto if required and/or move to GPU if the content was not written to it directly (GPU direct is not avaliable). + * Invoke the done callback. +* **class RdmaTensorResponse** - Holds and manages information for a single tensor response throughout the entire send cycle. API: + * **Start()** - Start the response sequence. + * Find the tensor in the local tag-match table. + * Compare the tensor's meta-data to the meta-data in the message (taken from the requester's local cache). + * If meta-data changed: + * Clone the tensor to be sent later. + * Send a meta-data update message and wait for re-request. + * Else: + * Send the tensor's content (using direct RDMA write). + * **Resume()** - Resume the response sequence after a re-request. Send the tensor's content that was cloned earlier. + * **Destroy()** - Destroy the response's resources and remove it form the pending list. +* **class RdmaAdapter** - The base for RDMA communications. It may contain multiple channels and buffers. It is responsible for handling various incoming RDMA messages. +* **class RdmaChannel** - Responsible for RDMA connection to a particular node. It manages messagee buffers. A channel has a request table which stores all the pending tensor requests. +* **class RdmaMessageBuffer** - Responsible for sending or receiving messages. It has a fixed size memory to store the data. It has a queue to store the pending jobs. A channel has two message buffers one for tx and one for rx. +* **class RdmaMgr** - Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc. +* **class RdmaRendezvousMgr** - Manages multiple rdma rendezvous. +* **class RdmaRemoteRendezvous** - A derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up. + +### Message structure: + +| type | name_size | name | step_id | request_index | remote_addr/checksum | rkey | is_dead | data_type | tensor_shape | tensor_bytes | error_status | +|------|---------- |------|---------|---------------|----------------------|------|---------|-----------|--------------|--------------|-----------------------| +| 1B | 2B | 512 | 8B | 8B | 8B | 4B | 1B | XB | XB | 8B | Size - 4B, proto - XB | + +* **RDMA_MESSAGE_TENSOR_REQUEST** - (receiver ==> sender) The original tensor request. + * type - The message type. + * name (name_size) - Name of the requested tensor. + * step_id - Step ID. + * request_index - Request index. + * remote_addr/rkey - Address/rkey of the result/proxy tensor. Irrelevant for first-time request. + * is_dead/data_type/tensor_shape/tensor_bytes - The current meta-data as stored in the receiver local cache. The sender will use that information to know if the receiver's cache requires updating. +* **RDMA_MESSAGE_META_DATA_RESPONSE** - (sender ==> receiver) The meta-data update message in case meta-data had changed (or if it is the first time the tensor is requested). + * type - The message type. + * request_index - Request index. + * is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data. + * checksum - In data validation mode, this will hold the checksum of the source tensor. +* **RDMA_MESSAGE_TENSOR_RE_REQUEST** - (receiver ==> sender) Tensor re-requset after meta-data update and reallocation of result/proxy tensors. + * type - The message type. + * name (name_size) - Name of the requested tensor. + * step_id - Step ID. + * request_index - Request index. + * remote_addr/rkey - Address/rkey of the reallocated result/proxy tensor. +* **RDMA_MESSAGE_ERROR_STATUS** - (sender ==> receiver) Notify the receiver that an error had occured on the sender side, so it can propagate it to the upper levels. + * type - The message type. + * name (name_size) - Name of the requested tensor. + * step_id - Step ID. + * request_index - Request index. + * error_status - The error status (code, message, details). diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.h b/tensorflow/contrib/verbs/grpc_verbs_client.h index 358977f92543e1a38b594cf45cdbff34f89277be..2cfaa4986cb0923d9687cb77b8e1116a937594a1 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_client.h +++ b/tensorflow/contrib/verbs/grpc_verbs_client.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ +#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ +#define TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" @@ -47,4 +47,4 @@ class GrpcVerbsClient { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ +#endif // TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_ diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index f2af6b79fba6a480afbfe88fcbefcbf8a6670ce6..742f946c9536973eb8a6a11afda1b32ae4a7726b 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -122,17 +122,15 @@ Status GrpcVerbsService::GetRemoteAddressSync( rc->SetRemoteAddress(ra, false); rc->Connect(); int i = 0; - int idx[] = {1, 0, 3, 2}; - std::vector mb(rc->message_buffers()); - CHECK_EQ(request->mr_size(), 4); + int idx[] = {1, 0}; + std::vector mb(rc->message_buffers()); + CHECK_EQ(request->mr_size(), RdmaChannel::kNumMessageBuffers); for (const auto& mr : request->mr()) { // the connections are crossed, i.e. // local tx_message_buffer <---> remote rx_message_buffer_ // local rx_message_buffer <---> remote tx_message_buffer_ - // local tx_ack_buffer <---> remote rx_ack_buffer_ - // local rx_ack_buffer <---> remote tx_ack_buffer_ - // hence idx[] = {1, 0, 3, 2}. - RdmaBuffer* rb = mb[idx[i]]; + // hence idx[] = {1, 0}. + RdmaMessageBuffer* rb = mb[idx[i]]; RemoteMR rmr; rmr.remote_addr = mr.remote_addr(); rmr.rkey = mr.rkey(); diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h index aa509602b51e7749547f1ff8eb5193acd1a3ec65..444c863b942ef8bce8d54d59765563b12eb6087e 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_ +#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_ +#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_ #ifdef TENSORFLOW_USE_VERBS @@ -69,4 +69,4 @@ void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env, } // namespace tensorflow #endif // TENSORFLOW_USE_VERBS -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_ +#endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_ diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 86431ca030c38c56155801202714ee4a49b764df..1f0f10517e98a32ae882c027330091928f1a6ee2 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ +#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ +#define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ #include "grpc++/impl/codegen/async_stream.h" #include "grpc++/impl/codegen/async_unary_call.h" @@ -86,4 +86,4 @@ class VerbsService GRPC_FINAL { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ +#endif // TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ diff --git a/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md new file mode 100644 index 0000000000000000000000000000000000000000..956b8f2147cf8154b6f1ade006d7bff194864c9b --- /dev/null +++ b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md @@ -0,0 +1,87 @@ +## Verbs implementation to use direct tensor writes (0 copies) + +### Motivation: + +Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), we wish to adopt the 0 copies approach and apply it to the current verbs implementation, while keeping the current implementation advantages, such as configurability and the use of RDMA for control messages. + +### Performance: + +Compared with the current GRPC, verbs and GDR implementation, the result implementation gave the best performance for every model, with any number of nodes. For VGG16 on 8 nodes with 4 P100 GPUs each, the prototype beat the second place by over 15%. + +### Implementation requirements: + +1. Tensor writes need to be done directly from the source Tensor to the destination Tensor, with no memory copies in between. This should be done for all DMAble tensors which are located either on CPU or on a RDMA compatible GPU device (GPU direct). +2. Non DMAble tensors (CanMemCopy == false) will be serialized to proto on the sender side, RDMA written to a registered buffer on the receiver side, and then deserialized by the receiver. +3. Tensors which are located on a non-RDMA-compatible GPU, will be RDMA written to a registered CPU proxy buffer on the receiver side, and then copied to GPU by the receiver. + +### Implementation constrains: + +For best stability and proof of correctness, we will divide the implementation to two stages: +1. At first stage we will keep changes to the current implementation to the minimum possible. The expense will be that we may have unused or unnecessary code leftovers, which may also affect performance. +2. At second stage, we will re-iterate over the code and remove irrelevant code parts. +The design of the solution aims that we will achieve both stages with relative ease. + +### Design guidelines: + +1. Since we do not want to do any unnecessary memory copying, we will no longer allocate a fixed CPU buffer as the destination for the RDMA write. Instead we will do the writing directly to the result tensor, or if the result tensor is on a device which does not support RDMA, we will do the writing to a proxy CPU tensor and then copy its content to the result tensor. +2. The address of the destination Tensor needs to be sent to the sender side for writing, meaning that the result/proxy tensor should be pre-allocated on the receiver side, prior to sending the tensor request. In order to do that, we need to know its meta-data, i.e. shape and data-type for DMAble tensors, and proto-size for serialized tensors. Unfortunately, this information is only available on the sender side which complicates manners. In order to avoid sending extra messages for querying the meta-data on each step, we store a local meta-data cache per tensor. Based on the assumption that the meta-data of a tensor rarely changes between steps, we expect that on most times the cache will only be updated once. When the sender receives a request for a tensor, if it is the first time this tensor is requested, or in the rare case that the meta-data did change, the sender will first send a meta-data response, on which the receiver will update the local cache, and reallocate the result/proxy tensors if required. When the receiver sends the tensor request, it will contain also the meta-data currently stored in its local cache, so the sender can compare it to see if there was a change. +3. When the sender writes the tensor content to the result tensor, no additional data is being written with it. That means we need to reside on ibverbs immediate (uint32_t) to indicate which request we are responding to (in order to trigger the receive callback). The easiest and most elegant way is to key the recv callback with a unique request_index (uint32_t), instead of the current key_with_step_id (string). +4. Since the sender no longer writes the tensor from/to fixed buffers, we no longer need to schedule the writes using the local/remote status. In addition we no longer rely on the RmdaTensorBuffer members as the source/destination addresses and rkey/lkey. Instead, each RdmaTensorBuffer will hold multiple "Response" objects (one per step-id), from which we derive destination address and rkey. The source address and lkey are always the ones of the source Tensor. +5. With the addition of tensor pre-allocation, we noticed there is a large code similarity between sending the first tensor request and re-sending the request in case of meta-data changes. After implementing a common method for tensor pre-allocation, it turned out that implementation becomes much simpler by encapsulating the process of request sending/re-sending, meta-data response callback and content response callback, all in a single "Request" class. The request class holds all the relevant request information, which reduces excessive parameter passing and lambda capturing. This decision is purely for elegance and code simplicity, and we decided to implement it in first stage because it makes the implementation much easier. + +### New types/classes: + +* **enum RdmaImmDataType** - Immediate types to distinguish between different RDMA writes on the remote side. Ack writes and control-message writes have a fixed immediate value. The rest of the writes are tensor writes and the immediate value is the relevant request index. +* **enum RdmaWriteIDType** - Types to distinguish between different RDMA write-complete events: Ack, control message, tensor DMA write and tensor proto write. +* **class RdmaWriteID** - Context for RDMA write complete events. Holds the RdmaWriteIDType and additional data. +* **class RemoteAddressContext** - Remote address information (address + mr). Will be passed as write context for tensor proto writes. +* **class RdmaTensorMetaData** - Meta-data for a tensor (type, shape, is_dead, proto_size). +* **class RdmaMemoryMgr** - Manages the meta-data cache, and the registered memory regions. +* **class RdmaTensorRequest** - Holds and manages information for a single tensor request throughout the entire receive cycle. API: + * Start() - Start the request. + * RecvTensorMetaData() - Receive meta-data from the remote side. + * RecvTensorContent() - Receive tensor content from the remote side and invoke the done() callback. +* **class RdmaTensorResponse** - Holds information for a single tensor response, such as destination address and rkey. + +### Protocol changes: + +The protocol messages themselves will remain mostly unchanged at the first stage, but will be used differently, as described below. The current messages structures already have most of the required fields for the new implementation. The only change is the "buffer_size" field which is no longer used since we are no longer sending additional information with the tensor, and thus it is now always equal to the "tensor_bytes" field. Instead, we use that field to pass the "request_index". + +### Message structure: + +| type | name_size | name | step_id | request_index | remote_addr | rkey | is_dead | data_type | tensor_shape | tensor_bytes | +|------|---------- |------|---------|---------------|-------------|------|---------|-----------|--------------|--------------| +| 1B | 2B | 512 | 8B | 8B | 8B | 4B | 1B | XB | XB | 8B | + +* **RDMA_MESSAGE_TENSOR_REQUEST** - (receiver ==> sender) The original tensor request. + * type - The message type. + * name (name_size) - Name of the requested tensor. + * step_id - Step ID. + * request_index - Request index. + * remote_addr/rkey - Address/rkey of the result/proxy tensor. Irrelevant for first-time request. + * is_dead/data_type/tensor_shape/tensor_bytes - The current meta-data as stored in the receiver local cache. The sender will use that information to know if the receiver's cache requires updating. +* **RDMA_MESSAGE_BUFFER_REQUEST** - (sender ==> receiver) The meta-data update message in case meta-data had changed (or if it is the first time the tensor is requested). + * type - The message type. + * request_index - Request index. + * is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data. +* **RDMA_MESSAGE_BUFFER_RESPONSE** - (receiver ==> sender) Tensor re-requset after meta-data update and reallocation of result/proxy tensors. + * type - The message type. + * name (name_size) - Name of the requested tensor. + * step_id - Step ID. + * request_index - Request index. + * remote_addr/rkey - Address/rkey of the reallocated result/proxy tensor. + * is_dead/data_type/tensor_shape/tensor_bytes - The new meta-data. Will be removed in the next phase. +* **RDMA_MESSAGE_TENSOR_WRITE** - (sender ==> receiver) No longer sent. There is only a direct write of the tensor content to the result/proxy tensor. Request index passed as the immediate value of the write. +* **RDMA_MESSAGE_TENSOR_IDLE** - (receiver ==> sender) No longer sent. + +![alt text](verbs_with_0_copies_phase1_protocol.jpg "Phase 1 message protocol") + +### Second stage optimizations: +1. Remove unused code leftovers. +2. Remove the ACK buffer completely, since we can rely completely on its immediate value. + +### Future optimizations: +1. Map the tensor names to indexes, to significantly reduce the request message size. +2. Understand the purpose of empty tensors and if we can skip remote fetching for them. +3. Consider concatenating multiple requests and/or using multiple message buffers. +4. Consider a no-request architecture. diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index ae9a384565a6ad0e63a6cf3acf07c591c65f0637..86350a08e57e5050f18d019fe80d70f6381c1f7d 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -15,58 +15,48 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include "tensorflow/contrib/verbs/rdma.h" #include #include -#include -#include "tensorflow/contrib/verbs/verbs_util.h" + +#include "tensorflow/contrib/verbs/rdma.h" +#include "tensorflow/contrib/verbs/verbs_service.pb.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/process_util.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" #endif #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { #define RoCE_V2 "RoCE v2" namespace { -// hash name to 32-bit integer -uint32_t NameHash(const string& name) { - return Hash32(name.data(), name.size(), 0x1234ABCD); -} // convenience function for printing message string MessageTypeToString(RdmaMessageType rmt) { switch (rmt) { - case RDMA_MESSAGE_ACK: - return "RDMA_MESSAGE_ACK"; - break; - case RDMA_MESSAGE_BUFFER_IDLE: - return "RDMA_MESSAGE_BUFFER_IDLE"; + case RDMA_MESSAGE_META_DATA_UPDATE: + return "RDMA_MESSAGE_META_DATA_UPDATE"; break; - case RDMA_MESSAGE_BUFFER_REQUEST: - return "RDMA_MESSAGE_BUFFER_REQUEST"; - break; - case RDMA_MESSAGE_BUFFER_RESPONSE: - return "RDMA_MESSAGE_BUFFER_RESPONSE"; + case RDMA_MESSAGE_TENSOR_RE_REQUEST: + return "RDMA_MESSAGE_TENSOR_RE_REQUEST"; break; case RDMA_MESSAGE_TENSOR_REQUEST: return "RDMA_MESSAGE_TENSOR_REQUEST"; break; - case RDMA_MESSAGE_TENSOR_WRITE: - return "RDMA_MESSAGE_TENSOR_WRITE"; - break; default: return "UNKNOWN MESSAGE"; } @@ -347,7 +337,7 @@ uint32_t set_param(uint32_t default_val, const char* env_param) { enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) { ibv_port_attr port_attr; - enum ibv_mtu mtu; + enum ibv_mtu mtu = IBV_MTU_512; string mtu_s; int rc, mtu_i; @@ -459,106 +449,79 @@ void RdmaAdapter::Process_CQ() { CHECK_GE(ne, 0); for (int i = 0; i < ne; ++i) { CHECK(wc_[i].status == IBV_WC_SUCCESS) - << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " " - << wc_[i].status << " " << static_cast(wc_[i].wr_id) << " " - << wc_[i].vendor_err; + << "Failed status \n" + << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " " + << static_cast(wc_[i].wr_id) << " " << wc_[i].vendor_err; if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) { RdmaChannel* rc = reinterpret_cast(wc_[i].wr_id); // put back a recv wr. rc->Recv(); // imm_data is the index of RX buffer in the buffer table. uint32_t imm_data = wc_[i].imm_data; - RdmaBuffer* rb = rc->FindBuffer(imm_data); + RdmaMessageBuffer* rb; RdmaMessage rm; - RdmaMessage::ParseMessage(rm, rb->buffer_); - VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_); - if (rm.type_ == RDMA_MESSAGE_ACK) { + if (imm_data == RDMA_IMM_DATA_ACK) { // receive an ack to a message rb = rc->tx_message_buffer_; rb->SetBufferStatus(remote, idle); rb->SendNextItem(); - } else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) { - // received a request-for-tensor message - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find or create buffer - RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_); - string key_with_step_id = - VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); - tb->EnqueueItem(key_with_step_id); - // send the next tensor - worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); }); - } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) { - // receive tensor-buffer-ready message - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find buffer - RdmaTensorBuffer* tb = - reinterpret_cast(rc->FindBuffer(rm.name_)); - tb->SetBufferStatus(remote, idle); - worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); }); - } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) { - // remote host requests to create a tensor buffer; - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find or create the buffer - RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR); - RemoteMR rmr; - rmr.remote_addr = rm.remote_addr_; - rmr.rkey = rm.rkey_; - tb->SetRemoteMR(rmr, true); - tb->CreateCPUBuffer(rm.buffer_size_); - // create RDMA_MESSAGE_BUFFER_RESPONSE message - RdmaMessage br; - br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE; - br.name_size_ = rm.name_.size(); - br.name_ = rm.name_; - br.buffer_size_ = rm.buffer_size_; - br.remote_addr_ = reinterpret_cast(tb->buffer_); - br.rkey_ = tb->self_->rkey; - string message = RdmaMessage::CreateMessage(br); - RdmaBuffer* mb = rc->tx_message_buffer_; - mb->EnqueueItem(message); - mb->SendNextItem(); - } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) { - // remote creates a buffer and responds - // send ack to release remote tx message buffer - RdmaBuffer* ab = rc->tx_ack_buffer_; - ab->SendNextItem(); - // find buffer - RdmaTensorBuffer* tb = - reinterpret_cast(rc->FindBuffer(rm.name_)); - CHECK(rm.buffer_size_ == tb->size_) - << "rm.buffer_size = " << rm.buffer_size_ - << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_; - RemoteMR rmr; - rmr.remote_addr = rm.remote_addr_; - rmr.rkey = rm.rkey_; - tb->SetRemoteMR(rmr, true); - tb->SetBufferStatus(local, idle); - tb->SetBufferStatus(remote, idle); - worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); }); - } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { - // tensor RDMA write completed - worker_env_->compute_pool->Schedule([rm, rc]() { - string key_with_step_id = - VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_); - rc->RunRecvCallback(key_with_step_id); - }); + continue; } - } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) { - RdmaBuffer* rb = reinterpret_cast(wc_[i].wr_id); - rb->SetBufferStatus(local, idle); - RdmaMessage rm; + + if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) { + // receive a tensor RDMA write + uint32_t request_index = imm_data; + RdmaTensorRequest* request = rc->GetTensorRequest(request_index); + request->RecvTensorContent(); + continue; + } + + // receive a control message + rb = rc->rx_message_buffer_; RdmaMessage::ParseMessage(rm, rb->buffer_); - VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_); - if (rm.type_ != RDMA_MESSAGE_ACK) { - worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); }); + RdmaMessageBuffer::SendAck(rc); + RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Received " << MessageTypeToString(rm.type_) << " " + << "#" << rm.request_index_ << ": " << rm.name_; + + if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) { + RdmaTensorResponse* response = rc->AddTensorResponse(rm); + response->Start(); + } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) { + RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_); + request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_, + rm.is_dead_, rm.tensor_bytes_); +#ifdef RDMA_DATA_VALIDATION + request->RecvTensorChecksum(rm.checksum_); +#endif + } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) { + RdmaTensorResponse* response = rc->UpdateTensorResponse(rm); + response->Resume(); + } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) { + RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_); + request->RecvErrorStatus(rm.status_); } + } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) { + RdmaWriteID* wr_id = reinterpret_cast(wc_[i].wr_id); + RDMA_LOG(2) << "Write complete of type " << wr_id->write_type; + switch (wr_id->write_type) { + case RDMA_WRITE_ID_ACK: + break; + case RDMA_WRITE_ID_MESSAGE: { + RdmaMessageBuffer* rb = + reinterpret_cast(wr_id->write_context); + rb->SetBufferStatus(local, idle); + rb->SendNextItem(); + break; + } + case RDMA_WRITE_ID_TENSOR_WRITE: { + RdmaTensorResponse* response = + reinterpret_cast(wr_id->write_context); + response->Destroy(); + } + } + delete wr_id; } } } @@ -577,7 +540,7 @@ int RdmaChannel::PingPostRecv() { int RdmaChannel::PingPostSend() { struct ibv_send_wr wr, *bad_wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t)this; wr.sg_list = &ping_sge_list_; wr.num_sge = 1; wr.opcode = IBV_WR_SEND; @@ -588,8 +551,10 @@ int RdmaChannel::PingPostSend() { RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, const string remote_name) - : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) { - + : adapter_(adapter), + local_name_(local_name), + remote_name_(remote_name), + request_serial_(0) { struct ibv_sge list; mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize, @@ -651,29 +616,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // create message and ack buffers, then initialize the tables. { - const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", - "tx_ack_buffer", "rx_ack_buffer"}; + const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"}; tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]); rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]); - tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]); - rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]); message_buffers_.reserve(kNumMessageBuffers); message_buffers_.push_back(tx_message_buffer_); message_buffers_.push_back(rx_message_buffer_); - message_buffers_.push_back(tx_ack_buffer_); - message_buffers_.push_back(rx_ack_buffer_); // create buffer on host tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize); rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize); - tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize); - rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize); - // bt_mu_.lock() is not used in constructor. - for (int i = 0; i < kNumMessageBuffers; i++) { - uint32_t index = NameHash(buffer_names[i]); - buffer_table_.insert({index, message_buffers_[i]}); - buffer_index_name_table_.insert({index, buffer_names[i]}); - buffer_name_index_table_.insert({buffer_names[i], index}); - } } CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_ << " with error " << std::strerror(errno); @@ -684,8 +635,6 @@ RdmaChannel::~RdmaChannel() { CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP"; delete tx_message_buffer_; delete rx_message_buffer_; - delete tx_ack_buffer_; - delete rx_ack_buffer_; } void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { @@ -711,119 +660,36 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { void RdmaChannel::Recv() { struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t)this; struct ibv_recv_wr* bad_wr; CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } -// Lookup 32-bit buffer index from buffer name -// Args: -// buffer_name: name of the buffer -// Returns: -// 32-bit index -uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) { - mutex_lock lock{bt_mu_}; - BufferNameIndexTable::iterator iter = - buffer_name_index_table_.find(buffer_name); - CHECK(iter != buffer_name_index_table_.end()); - return iter->second; -} - -// Find a buffer by its 32-bit index -// Args: -// index: 32-bit hash code of the tensor buffer name -// Returns: -// name of the tensor buffer -RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) { - mutex_lock lock{bt_mu_}; - BufferTable::iterator iter = buffer_table_.find(index); - CHECK(iter != buffer_table_.end()); - return iter->second; -} - -// Find a buffer by its name -// Args: -// name: name of the buffer -// Returns: -// the named rdma buffer -RdmaBuffer* RdmaChannel::FindBuffer(const string& name) { - uint32_t index = LookupBufferIndex(name); - return FindBuffer(index); -} - -// Find a buffer if it exists, otherwise create one. -// The memory inside the created buffer is not allocated. -// Args: -// name: the name of the buffer -// buffer_type: TENSOR, MESSAGE or ACK. -// Returns: -// the named buffer -RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name, - BufferType buffer_type) { - mutex_lock lock{bt_mu_}; - RdmaBuffer* rb; - // find index - BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name); - if (iter != buffer_name_index_table_.end()) { - uint32_t index = iter->second; - // find buffer - BufferTable::iterator iter = buffer_table_.find(index); - CHECK(iter != buffer_table_.end()); - rb = iter->second; - } else { - uint32_t index = NameHash(name); - if (buffer_type == TENSOR) { - rb = new RdmaTensorBuffer(this, name); - } else if (buffer_type == MESSAGE) { - rb = new RdmaMessageBuffer(this, name); - } else if (buffer_type == ACK) { - rb = new RdmaAckBuffer(this, name); - } - buffer_name_index_table_.insert({name, index}); - buffer_index_name_table_.insert({index, name}); - buffer_table_.insert({index, rb}); +RdmaTensorRequest* RdmaChannel::InsertTensorRequest( + const string& key, int64 step_id, Device* dst_dev, + const Rendezvous::Args recv_args, + const RdmaTensorRequest::RecvDoneCallback& done) { + mutex_lock lock{ct_mu_}; + uint32_t request_index = request_serial_++; + if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) { + request_serial_ = 0; } - CHECK(rb); - return rb; + RdmaTensorRequest request(request_index, key, step_id, this, dst_dev, + recv_args, done); + auto it = request_table_.emplace(request_index, request); + return &it.first->second; } -// Insert callback to the callback_table. -// The callback is activated when the corresponding tensor is received. -// Arg: -// key: the name of the tensor -// recv_done: the callback associated with the tensor. -// Returns: -// None -void RdmaChannel::InsertRecvCallback(const string& key, - std::function recv_done) { +void RdmaChannel::RemoveTensorRequest(uint32_t request_index) { mutex_lock lock{ct_mu_}; - callback_table_.insert({key, recv_done}); + request_table_.erase(request_index); } -// Remove callback from the callback_table. -// Arg: -// key: the name of the tensor -// Returns: -// None -void RdmaChannel::RemoveRecvCallback(const string& key) { +RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) { mutex_lock lock{ct_mu_}; - callback_table_.erase(key); -} - -// Run named callback in the callback_table. -// Arg: -// key: the name of the tensor -// Returns: -// None -void RdmaChannel::RunRecvCallback(const string& key) { - std::function recv_done; - { - mutex_lock lock{ct_mu_}; - CallbackTable::iterator iter = callback_table_.find(key); - CHECK(iter != callback_table_.end()); - recv_done = iter->second; - } - recv_done(); + RequestTable::iterator iter = request_table_.find(request_index); + CHECK(iter != request_table_.end()); + return &iter->second; } void RdmaChannel::Connect() { @@ -865,11 +731,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class; int r; - CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV | - IBV_QP_PATH_MTU | - IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_MIN_RNR_TIMER))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r; memset(&attr, 0, sizeof(ibv_qp_attr)); @@ -880,33 +746,30 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.rnr_retry = 7; /* infinite */ attr.max_rd_atomic = 1; - CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT | - IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r; connected_ = true; } else { - LOG(INFO) << "channel already connected"; + RDMA_LOG(2) << "channel already connected"; } } -RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name) +RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name) : channel_(channel), name_(name) {} -RdmaBuffer::~RdmaBuffer() { +RdmaMessageBuffer::~RdmaMessageBuffer() { CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed"; FreeBuffer(); } -void RdmaBuffer::FreeBuffer() { +void RdmaMessageBuffer::FreeBuffer() { if ((buffer_ != nullptr) && buffer_on_host_) { free(buffer_); } - // TODO - // release buffer if it is on device. - // We don't support RDMABuffer on device at this moment. } // Allocate CPU memory for the Rdma buffer @@ -915,7 +778,7 @@ void RdmaBuffer::FreeBuffer() { // lock: whether or not mutex_lock the process to protect concurrency. // Returns: // None -void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) { +void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) { CHECK(size > 0); if (lock) { mu_.lock(); @@ -943,7 +806,7 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) { // override: whether override existing information // Returns: // None -void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) { +void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) { mutex_lock lock{mu_}; if ((override) || (remote_status_ == none)) { remote_.remote_addr = rmr.remote_addr; @@ -956,63 +819,51 @@ void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) { } // Put a task in the buffer's job queue -void RdmaBuffer::EnqueueItem(string item) { +void RdmaMessageBuffer::EnqueueItem(string item) { mutex_lock lock{mu_}; queue_.push(item); } // Rdma-Write the content of the buffer -void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { +void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) { + Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey, + remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this); +} + +// Generalized Write method +void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data, + size_t buffer_size, uint64_t src_addr, + uint32_t lkey, uint64_t remote_addr, + uint32_t rkey, RdmaWriteIDType write_type, + void* write_context) { struct ibv_sge list; - list.addr = (uint64_t)buffer_; + list.addr = src_addr; list.length = buffer_size; - list.lkey = self_->lkey; + list.lkey = lkey; struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t) this; + wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context); wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wr.send_flags = IBV_SEND_SIGNALED; wr.imm_data = imm_data; - wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr; - wr.wr.rdma.rkey = remote_.rkey; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = rkey; struct ibv_send_wr* bad_wr; - CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send"; -} - -RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} - -RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} - -RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name) - : RdmaBuffer(channel, name) {} - -RdmaTensorBuffer::~RdmaTensorBuffer() { - for (Itable it = retable.begin(); it != retable.end(); ++it) { - delete (it->second); - } + CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send"; } // Send the next ack from the buffer's job queue. -void RdmaAckBuffer::SendNextItem() { - uint32_t imm_data = LookupBufferIndex("rx_ack_buffer"); - RdmaMessage rm; - rm.name_ = "rx_ack_buffer"; - rm.type_ = RDMA_MESSAGE_ACK; - rm.name_size_ = rm.name_.size(); - string message = RdmaMessage::CreateMessage(rm); - memcpy(buffer_, message.data(), message.size()); - Write(imm_data, message.size()); +void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) { + Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr); } // Send the next message from the buffer's job queue. void RdmaMessageBuffer::SendNextItem() { - uint32_t imm_data = LookupBufferIndex("rx_message_buffer"); + uint32_t imm_data = RDMA_IMM_DATA_MESSAGE; mu_.lock(); if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) { local_status_ = busy; @@ -1029,244 +880,390 @@ void RdmaMessageBuffer::SendNextItem() { } } -Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( - const string& key_with_step_id, const string& key, int64 step_id, - const Rendezvous::ParsedKey& parsed) { - Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id, parsed]( - const Status& status, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { - CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id - << " error message: " << status.error_message(); - size_t buffer_size = RdmaMessage::kMessageTotalBytes; - size_t tensor_bytes = 0; - // Figures out which device the tensor is hosted on. - Device* src_dev = nullptr; - Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice( - parsed.src_device, &src_dev); - CHECK(s.ok()) << "src device not found"; - // Does the device have the right incarnation number we expect? - CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation) - << "RecvTensor expects a different device incarnation: " - << parsed.src_incarnation << " vs. " - << src_dev->attributes().incarnation() - << ". Your worker job was probably restarted. Check your " - << "worker job for the reason why it was restarted."; - Device* dst_dev = nullptr; - // destination is on CPU. - s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0", - &dst_dev); - CHECK(s.ok()) << "dst device not found"; - AllocatorAttributes dst_alloc_attr; - dst_alloc_attr.set_on_host(true); - - bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); - // string tensor needs to be serialized - Tensor copy; - TensorProto proto; - if (src_dev->tensorflow_gpu_device_info() && - (!send_args.alloc_attrs.on_host())) { #if GOOGLE_CUDA - CHECK(send_args.device_context) << "send dev name: " << src_dev->name() - << " gpu_info: " - << src_dev->tensorflow_gpu_device_info(); - - if (can_memcpy) { - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); - copy = Tensor(alloc, in.dtype(), in.shape()); - tensor_bytes = in.TotalBytes(); - buffer_size += tensor_bytes; - GPUUtil::CopyGPUTensorToCPU( - src_dev, send_args.device_context, &in, ©, - [this, copy, tensor_bytes, buffer_size, key, in, step_id, - key_with_step_id, is_dead, send_args, recv_args](const Status& s) { - CHECK(s.ok()) << "copy tensor from gpu sync"; - StringPiece copy_buf; - copy_buf = copy.tensor_data(); - PostCopyOperations(true, buffer_size, tensor_bytes, key, in, - step_id, is_dead, key_with_step_id, ©, - NULL, ©_buf, send_args, recv_args); - }); - } else { - // "val" is on a GPU. No longer uses GPUUtil to fill the proto, use - // aync instead - GPUUtil::SetProtoFromGPU( - in, src_dev, send_args.device_context, &proto, is_dead, - [this, proto, buffer_size, key, in, step_id, key_with_step_id, - is_dead, send_args, recv_args](const Status& s) mutable { - CHECK(s.ok()) << "copy proto from gpu sync"; - auto tensor_bytes = proto.ByteSize(); - buffer_size += tensor_bytes; - PostCopyOperations(false, buffer_size, tensor_bytes, key, in, - step_id, is_dead, key_with_step_id, NULL, - &proto, NULL, send_args, recv_args); - }); - } +static void CountCopies(const std::string& key, void* src_addr, void* dst_addr, + size_t tensor_bytes, bool is_gpu_to_cpu) { +#ifdef RDMA_COUNT_COPIES + static uint64_t numGPUToCPUCopies = 0; + static uint64_t numGPUToCPUCopiedBytes = 0; + static uint64_t numCPUToGPUCopies = 0; + static uint64_t numCPUToGPUCopiedBytes = 0; + static uint64_t numTotalCopies = 0; + + if (is_gpu_to_cpu) { + ++numGPUToCPUCopies; + numGPUToCPUCopiedBytes += tensor_bytes; + } else { + ++numCPUToGPUCopies; + numCPUToGPUCopiedBytes += tensor_bytes; + } + if ((++numTotalCopies % 0x400) == 0) { + RDMA_LOG(0) << "Tensor copies:" + << " GPU to CPU: " << numGPUToCPUCopies << " (" + << numGPUToCPUCopiedBytes << " Bytes)" + << " CPU to GPU: " << numCPUToGPUCopies << " (" + << numCPUToGPUCopiedBytes << " Bytes)"; + } + RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr + << " To: " << dst_addr; +#endif // RDMA_COUNT_COPIES +} #endif // GOOGLE_CUDA - } else { - // tensor is in CPU memory. - StringPiece copy_buf; - if (can_memcpy) { - copy_buf = in.tensor_data(); - tensor_bytes = in.TotalBytes(); - } else { - in.AsProtoTensorContent(&proto); - tensor_bytes = proto.ByteSize(); - } - buffer_size += tensor_bytes; - PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in, - step_id, is_dead, key_with_step_id, ©, &proto, - ©_buf, send_args, recv_args); + +#ifdef RDMA_DATA_VALIDATION +static uint64_t Checksum(Device* device, const DeviceContext* device_context, + const Tensor& in) { + uint64 checksum = 0; + if (DataTypeCanUseMemcpy(in.dtype())) { +#if GOOGLE_CUDA + if (in.TotalBytes() == 0) { + return 0; } - }; - return cb; + checksum = (device_context != nullptr) + ? GPUUtil::Checksum(device, device_context, in) + : GPUUtil::Checksum(in); +#endif // GOOGLE_CUDA + } else { + string s = in.SummarizeValue(999999); + checksum = Hash64(s.c_str(), s.size(), 0); + } + return checksum; } -// Send the next tensor from the buffer's job queue. -void RdmaTensorBuffer::SendNextItem() { - // get the key - string key_with_step_id = ""; - { - mutex_lock lock{mu_}; - if (!queue_.empty()) { - key_with_step_id = queue_.front(); - queue_.pop(); +static void ValidateChecksum(uint64_t expected, uint64_t actual, + const Tensor& in, uint32_t request_index, + const std::string& key, const std::string& msg) { + RDMA_LOG(2) << "Request #" << request_index << ": " << key + << ": Checksum: " << std::hex << " Expected = 0x" << expected + << ". Actual = 0x" << actual << "."; + + if (expected != actual) { + // Checksum failed. There is one case where this is allowed - if the + // tensor is an AssignAdd of the global step. Since the data-validation + // always postpones the Tensor response in order to send a checksum message, + // it is possible that the global-step was updated while the response was + // still in queue. + if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) { + int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1; + actual = Hash64((const char*)&prev_val, 8, 0); + } + if (expected != actual) { + LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #" + << request_index << ": " << key << std::hex << " " + << DataTypeString(in.dtype()) << " " + << in.shape().DebugString() << " (0x" << in.TotalBytes() + << " bytes): " + << " Expected 0x" << expected << ". Got 0x" << actual << "."; } } +} +#endif // RDMA_DATA_VALIDATION + +#if GOOGLE_CUDA +// Sync the 'done' operation on the GPU stream, but without all the data +// copying. +static void StreamGPUOp(Device* gpu_device, const DeviceContext* device_context, + StatusCallback done) { + Tensor dummy1, dummy2; + GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2, + done); +} +#endif // GOOGLE_CUDA + +RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) { + mutex_lock lock{mu_}; + auto it = + responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm)); + CHECK(it.second) << "Response with the ID " << rm.request_index_ + << " already exists."; + return &it.first->second; +} + +RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) { + mutex_lock lock{mu_}; + auto it = responses_table_.find(rm.request_index_); + CHECK(it != responses_table_.end()) << "No response found."; + RdmaTensorResponse* response = &it->second; + response->Update(rm); + return response; +} + +void RdmaChannel::RemoveTensorResponse(uint32_t request_index) { + mutex_lock lock{mu_}; + responses_table_.erase(request_index); +} + +void RdmaTensorResponse::Start() { + Rendezvous::ParsedKey parsed; + Status s = Rendezvous::ParseKey(rm_.name_, &parsed); + if (!s.ok()) { + SendErrorStatus(s); + return; + } - // send the tensor if a key is acquired. - if (key_with_step_id != "") { - VLOG(2) << "try to send tensor: " << key_with_step_id; - string key; - int64 step_id; - VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id); - CHECK(key.compare(name_) == 0); - Rendezvous::ParsedKey parsed; - Rendezvous::ParseKey(key, &parsed); - Rendezvous::DoneCallback cb = - getRecvTensorCallback(key_with_step_id, key, step_id, parsed); - channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id, - parsed, cb); + channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync( + rm_.step_id_, parsed, + [this, parsed](const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, + bool is_dead) { + CHECK(status.ok()) << "RecvLocalAsync was not ok." + << " error message: " << status.error_message(); + RecvHandler(parsed, send_args, recv_args, in, is_dead); + }); +} + +void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); } + +// Helper for RecvTensor. Validates "key" and returns the source +// device in "*src_dev". +Status RdmaTensorResponse::PrepareRecvTensor( + const Rendezvous::ParsedKey& parsed, Device** src_dev) { + // Figures out which device the tensor is hosted on. + string local_name = DeviceNameUtils::LocalName(parsed.src_device); + TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice( + local_name, src_dev)); + + // Does the device have the right incarnation number we expect? + if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { + return errors::Aborted( + "RecvTensor expects a different device incarnation: ", + parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), + ". Your worker job was probably restarted. Check your " + "worker job for the reason why it was restarted."); } + + return Status::OK(); } -void RdmaTensorBuffer::ReSendNextItem() { - // get the key - string key_with_step_id = ""; - { - mutex_lock lock{mu_}; - if (!requeue.empty()) { - key_with_step_id = requeue.front(); - requeue.pop(); - } +void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& in, bool is_dead) { + Status s = PrepareRecvTensor(parsed, &src_dev_); + if (!s.ok()) { + SendErrorStatus(s); + return; } - // send the tensor if a key is acquired. - if (key_with_step_id != "") { - VLOG(2) << "try to send tensor: " << key_with_step_id; - string key; - int64 step_id; - VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id); - CHECK(key.compare(name_) == 0); - Rendezvous::ParsedKey parsed; - Rendezvous::ParseKey(key, &parsed); - Rendezvous::DoneCallback cb = - getRecvTensorCallback(key_with_step_id, key, step_id, parsed); - ReItem* item; - { - mutex_lock lock{mu_}; - Itable it = retable.find(key_with_step_id); - CHECK(it != retable.end()) << "Could not find dup-recv context"; - item = it->second; - retable.erase(it); + meta_data_changed_ = TensorMetaDataChanged(in, is_dead); +#ifdef RDMA_DATA_VALIDATION + // Always send a meta data message with the source checksum + meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST; + checksum_ = Checksum(src_dev_, send_args.device_context, in); +#endif + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + // string tensor needs to be serialized + Tensor copy; + TensorProto proto; + const bool on_host = send_args.alloc_attrs.on_host(); + if (src_dev_->tensorflow_gpu_device_info() && !on_host) { +#if GOOGLE_CUDA + DeviceContext* send_dev_context = send_args.device_context; + CHECK(send_dev_context) + << "send dev name: " << src_dev_->name() + << " gpu_info: " << src_dev_->tensorflow_gpu_device_info(); + + if (can_memcpy) { + // If the tensor is located on a GDR compatible GPU, there is no need to + // copy it. We can send directly from the source, just need to make sure + // we are in sync with the GPU stream. + // If the tensor's meta-data changed however, we will need to clone it, + // so anyway we'll have to copy it from GPU to CPU first. If at some + // point in time Clone() is changed to only save a shallow copy, we can + // skip the copy here as well. + if ((in.TotalBytes() > 0) && !meta_data_changed_ && + (RdmaMemoryMgr::Singleton().FindMemoryRegion( + (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) { + StreamGPUOp(src_dev_, send_dev_context, + [this, in, proto, is_dead](const Status& s) { + Send(in, proto, is_dead, s); + }); + return; + } + + // The tensor must be copied from GPU to CPU, because either: + // 1. The tensor is located on a non GDR compatible GPU. + // 2. The tensor's meta-data has changed. + Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + copy = Tensor(alloc, in.dtype(), in.shape()); + CountCopies(rm_.name_, (void*)DMAHelper::base(&in), + (void*)DMAHelper::base(©), in.TotalBytes(), true); + GPUUtil::CopyGPUTensorToCPU( + src_dev_, send_dev_context, &in, ©, + [this, copy, proto, is_dead](const Status& s) { + Send(copy, proto, is_dead, s); + }); + } else { + GPUUtil::SetProtoFromGPU( + in, src_dev_, send_args.device_context, &proto, is_dead, + [this, in, proto, is_dead](const Status& s) mutable { + Send(in, proto, is_dead, s); + }); + } +#else + SendErrorStatus(errors::Internal("No GPU device in process")); +#endif // GOOGLE_CUDA + } else { + // tensor is in CPU memory. + if (!can_memcpy) { + in.AsProtoTensorContent(&proto); } - cb(Status::OK(), item->send_args, item->recv_args, item->in, item->is_dead); - delete (item); + Send(in, proto, is_dead, Status::OK()); + } +} + +void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto, + bool is_dead, const Status& status) { + if (!status.ok()) { + SendErrorStatus(status); + return; + } + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + bool proto_size_changed = + (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_); + if (meta_data_changed_ || proto_size_changed) { + Clone(in, proto, is_dead); + SendMetaData(in, proto, is_dead); + } else { + SendContent(in, proto, is_dead); } } -void RdmaTensorBuffer::PostCopyOperations( - bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key, - const Tensor& in, int64 step_id, bool is_dead, - const string& key_with_step_id, const Tensor* copy, - const TensorProto* proto, const StringPiece* copy_buf, - const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args) { - // prepare message +bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) { + return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) || + (rm_.is_dead_ != is_dead); +} + +void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto, + bool is_dead) { + // Clone the data to be sent later. For simplicity, we clone the tensor's + // data even if it is already a copy. Performance is less of a concern here + // since the meta-data hardly ever changes. The reason we create a copy, is + // that some tensors share their buffer between different step-ids, so the + // tensor content may change before re-request was completed. + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + if (can_memcpy && (in.TotalBytes() > 0)) { + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_nic_compatible(true); + host_alloc_attrs.set_on_host(true); + Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs); + tensor_ = new Tensor(allocator, in.dtype(), in.shape()); + memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes()); + } else { + tensor_ = new Tensor(in.dtype(), in.shape()); + } + if (!can_memcpy) { + proto_ = new TensorProto(proto); + } + is_dead_ = is_dead; +} + +void RdmaTensorResponse::SendMetaData(const Tensor& in, + const TensorProto& proto, bool is_dead) { + RDMA_LOG(2) << "Request #" << rm_.request_index_ + << ": Meta data changed: " << rm_.name_; + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize(); + + // Send meta-data update: RdmaMessage rm; - rm.name_size_ = key.size(); - rm.name_ = key; + rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE; + rm.name_size_ = rm_.name_.size(); + rm.name_ = rm_.name_; rm.tensor_shape_ = in.shape(); rm.data_type_ = in.dtype(); - rm.step_id_ = step_id; + rm.step_id_ = rm_.step_id_; rm.is_dead_ = is_dead; rm.tensor_bytes_ = tensor_bytes; - rm.buffer_size_ = buffer_size; - mu_.lock(); - if (local_status_ == none || (buffer_size > size_ && local_status_ == idle && - remote_status_ == idle)) { - if ((local_status_ != none) && (buffer_size > size_)) { - VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size; - } - CreateCPUBuffer(buffer_size, false); - // Need to be received again, put into the re-recv queue and the table - requeue.push(key_with_step_id); - ReItem* item = new ReItem(send_args, recv_args, in, is_dead); - retable.insert(std::pair(key_with_step_id, item)); - mu_.unlock(); - // no longer used: put back the key since it is not sent; - // ask the remote to create the same buffer - rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST; - rm.remote_addr_ = reinterpret_cast(buffer_); - rm.rkey_ = self_->rkey; - string message = RdmaMessage::CreateMessage(rm); - channel_->tx_message_buffer_->EnqueueItem(message); - channel_->tx_message_buffer_->SendNextItem(); - } else if ((local_status_ == idle) && (remote_status_ == idle)) { - // both buffers are ready, send the tensor - local_status_ = busy; - remote_status_ = busy; - // local/remote_status_ won't be set back to idle - // unitl Write() is successful - mu_.unlock(); - if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) || - (buffer_size <= size_ && rm.data_type_ == DT_STRING))) { - VLOG(2) << "Tensor and buffer size do not agree," - << " buffer_size = " << size_ - << " requested tensor size = " << buffer_size << in.DebugString(); - } - uint32_t imm_data = LookupBufferIndex(key); - rm.type_ = RDMA_MESSAGE_TENSOR_WRITE; - string message = RdmaMessage::CreateMessage(rm); - memcpy(buffer_, message.data(), message.size()); - if (!is_dead) { - // copy the tensor buffer content - void* output = static_cast(static_cast(buffer_) + - RdmaMessage::kTensorBufferStartIndex); - CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_); - if (can_memcpy) { - CHECK(copy != NULL) << "callback missing pointer to copy tensor"; - CHECK(copy_buf != NULL) << "callback missing pointer to copy buffer"; - CHECK(copy_buf->size() == tensor_bytes) - << "unexpected tensor size: " << copy_buf->size() - << " != " << tensor_bytes; - memcpy(output, copy_buf->data(), tensor_bytes); - } else { - CHECK(proto != NULL) << "callback missing pointer to proto tensor"; - proto->SerializeToArray(output, tensor_bytes); + rm.request_index_ = rm_.request_index_; +#ifdef RDMA_DATA_VALIDATION + rm.checksum_ = checksum_; +#endif + RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #" + << rm.request_index_ << ": " << rm.name_ + << " (shape = " << rm.tensor_shape_.DebugString() << "." + << " data-type = " << DataTypeString(rm.data_type_) << "." + << " is-dead = " << rm.is_dead_ << ")"; + + string message = RdmaMessage::CreateMessage(rm); + channel_->tx_message_buffer_->EnqueueItem(message); + channel_->tx_message_buffer_->SendNextItem(); +} + +void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto, + bool is_dead) { + bool can_memcpy = DataTypeCanUseMemcpy(in.dtype()); + size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize(); + uint32_t imm_data = rm_.request_index_; + if (!is_dead) { + if (can_memcpy) { + src_buffer_ = const_cast(DMAHelper::buffer(&in)); + if (src_buffer_ != nullptr) { + src_buffer_->Ref(); // Keep buffer alive until write is complete + src_addr_ = src_buffer_->data(); + mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_, + tensor_bytes); } } else { - buffer_size = RdmaMessage::kMessageTotalBytes; + RDMA_LOG(2) << "Encoding proto: " << rm_.name_ + << " (Size: " << tensor_bytes << ") " << in.DebugString(); + src_addr_ = malloc(tensor_bytes); + mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + proto.SerializeToArray(src_addr_, tensor_bytes); } - Write(imm_data, buffer_size); } else { - // Need to be received again, put into the re-recv queue and the table - requeue.push(key_with_step_id); - ReItem* item = new ReItem(send_args, recv_args, in, is_dead); - retable.insert(std::pair(key_with_step_id, item)); - mu_.unlock(); + tensor_bytes = 0; + } + + uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey; + RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec + << ": Sending tensor content #" << rm_.request_index_ << " from " + << std::hex << src_addr_ << " (0x" << lkey << ")" + << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_ + << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes + << ")"; + + RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes, + (uint64_t)src_addr_, lkey, rm_.remote_addr_, + rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this); +} + +void RdmaTensorResponse::SendErrorStatus(const Status& status) { + RdmaMessage rm; + rm.type_ = RDMA_MESSAGE_ERROR_STATUS; + rm.name_size_ = rm_.name_.size(); + rm.name_ = rm_.name_; + rm.step_id_ = rm_.step_id_; + rm.request_index_ = rm_.request_index_; + rm.status_ = status; + LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_ + << ": " << rm.name_ << ". Status: " << status.ToString(); + + string message = RdmaMessage::CreateMessage(rm); + channel_->tx_message_buffer_->EnqueueItem(message); + channel_->tx_message_buffer_->SendNextItem(); + + // Destroy the response. + Destroy(); +} + +void RdmaTensorResponse::Destroy() { + if (src_buffer_ != nullptr) { + src_buffer_->Unref(); + } + if (tensor_ != nullptr) { + delete tensor_; + } + if (proto_ != nullptr) { + ibv_dereg_mr(mr_); + free(src_addr_); + delete proto_; } + // Remove response from the pending list: + channel_->RemoveTensorResponse(rm_.request_index_); } // Create a RdmaMessage according to the pre-defined format @@ -1276,43 +1273,46 @@ void RdmaTensorBuffer::PostCopyOperations( // message in string format string RdmaMessage::CreateMessage(const RdmaMessage& rm) { // Rdma Message format - // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|... - // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... - // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer - // ...| XB | XB | 8B |... + // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|... + // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... + // ...|data_type|tensor_shape|tensor_bytes|error_status | + // ...| XB | XB | 8B |size - 4B, proto - XB | // - // ACK: type|13|"rx_ack_buffer" - // TENSOR_REQUEST: type|name_size|tensor_name|step_id - // TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead - // |data_type|tensor_shape|tensor_bytes - // BUFFER_IDLE: type|name_size|buffer_name - // BUFFER_REQUEST: - // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| - // BUFFER_RESPONSE: - // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey| - char message[kMessageTotalBytes]; + // ACK: Imm-type: ACK + // TENSOR_REQUEST: Imm-type: MESSAGE + // Fields: type, request_index, name, step_id, remote_addr, + // rkey, is_dead, data_type, tensor_shape, tensor_bytes + // META_DATA_UPDATE: Imm-type: MESSAGE + // Fields: type, request_index, is_dead, data_type, + // tensor_shape, tensor_bytes + // TENSOR_RE_REQUST: Imm-type: MESSAGE + // Fields: type, request_index, name, step_id, remote_addr, + // rkey, is_dead, data_type, tensor_shape, tensor_bytes + // ERROR_STATUS: Imm-type: MESSAGE + // Fields: type, request_index, name, step_id, error_status + // Tensor content: Imm-type: request_index + size_t message_size = kMessageTotalBytes; + char message[kMessageTotalBytes + kErrorStatusMaxSize]; // type message[kTypeStartIndex] = static_cast(rm.type_) & 0xff; - // size of name - memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_)); - // name - memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size()); - // buffer_size, remote_addr, rkey - if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) || - (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) { - memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_, - sizeof(rm.buffer_size_)); + // request index + memcpy(&message[kRequestIndexStartIndex], &rm.request_index_, + sizeof(rm.request_index_)); + // name, step_id, remote_addr, rkey + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { + memcpy(&message[kNameSizeStartIndex], &rm.name_size_, + sizeof(rm.name_size_)); + memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size()); memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_, sizeof(rm.remote_addr_)); memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_)); - } - // step_id - if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || - (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) { memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_)); } // is_dead, data_type, tensor_shape, tensor_bytes - if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_)); memcpy(&message[kDataTypeStartIndex], &rm.data_type_, @@ -1322,7 +1322,30 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) { memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_, sizeof(rm.tensor_bytes_)); } - return string(message, kMessageTotalBytes); + // checksum +#ifdef RDMA_DATA_VALIDATION + memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_)); +#endif + // error status + if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) { + ::grpc::Status gs = ToGrpcStatus(rm.status_); + ErrorStatusProto gsProto; + gsProto.set_error_code(gs.error_code()); + gsProto.set_error_message(gs.error_message()); + gsProto.set_error_details(gs.error_details()); + uint32_t gsProtoSize = gsProto.ByteSize(); + if (gsProtoSize + 4 > kErrorStatusMaxSize) { + LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) " + << "is too big to fit in RDMA message (" << kErrorStatusMaxSize + << " bytes). Truncated."; + gsProtoSize = kErrorStatusMaxSize - 4; + } + uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex]; + *proto_size = gsProtoSize; + gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize); + message_size += gsProtoSize + 4; + } + return string(message, message_size); } // Parse a RdmaMessage according to the pre-defined format @@ -1335,26 +1358,24 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) { char* message = static_cast(buffer); // type rm.type_ = static_cast(message[kTypeStartIndex]); - // name_size_ - memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_)); - // name - rm.name_ = string(&message[kNameStartIndex], rm.name_size_); - // buffer_size, remote_addr, rkey - if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) || - (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) { - memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex], - sizeof(rm.buffer_size_)); + // request index + memcpy(&rm.request_index_, &message[kRequestIndexStartIndex], + sizeof(rm.request_index_)); + // name, step_id, remote_addr, rkey + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { + memcpy(&rm.name_size_, &message[kNameSizeStartIndex], + sizeof(rm.name_size_)); + rm.name_ = string(&message[kNameStartIndex], rm.name_size_); memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex], sizeof(rm.remote_addr_)); memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_)); - } - // step_id - if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) || - (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) { memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_)); } // data_type, tensor_bytes, tensor_shape, is_dead - if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) { + if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) || + (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) || + (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) { memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_)); memcpy(&rm.data_type_, &message[kDataTypeStartIndex], sizeof(rm.data_type_)); @@ -1363,6 +1384,291 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) { memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex], sizeof(rm.tensor_bytes_)); } + // checksum +#ifdef RDMA_DATA_VALIDATION + memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_)); +#endif + // error status + if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) { + ErrorStatusProto gsProto; + uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex]; + CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4], + gsProtoSize)) + << "Failed to parse error status proto from message. Aborting."; + ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(), + gsProto.error_message(), gsProto.error_details()); + rm.status_ = FromGrpcStatus(gs); + } +} + +//***************************************************************************** +// RdmaMemoryMgr +//***************************************************************************** + +ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) { + mutex_lock l(mrs_mu_); + auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); + if (iter == std::end(mrs_) || iter->get()->addr > addr) { + return nullptr; + } else { + return iter->get(); + } +} + +void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name) { + if (length == 0) return; + ibv_mr* mr = ibv_reg_mr(pd_, addr, length, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". [" + << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]" + << " SIZE: 0x" << length << " (" << allocator_name << ")."; + if (mr != nullptr) { + mutex_lock l(mrs_mu_); + auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); + mrs_.insert(iter, {mr, &MRDeleter}); + } else { + LOG(WARNING) << "Cannot register memory region"; + } +} + +void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) { + if (length == 0) return; + mutex_lock l(mrs_mu_); + auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); + if (iter != std::end(mrs_) && iter->get()->addr == addr) { + mrs_.erase(iter); + RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey; + + } else { + LOG(WARNING) << "Failed to de-register memory region"; + } +} + +const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData( + const std::string& tensor_name) { + mutex_lock l(tensor_meta_data_mu_); + auto it = tensors_meta_data_.find(tensor_name); + if (it == tensors_meta_data_.end()) { + return nullptr; + } + return &it->second; +} + +const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData( + const std::string& tensor_name, DataType dtype, const TensorShape& shape, + bool is_dead, size_t proto_size) { + mutex_lock l(tensor_meta_data_mu_); + TensorMetaData& meta_data = tensors_meta_data_[tensor_name]; + meta_data.data_type_ = dtype; + meta_data.tensor_shape_ = shape; + meta_data.proto_size_ = proto_size; + meta_data.is_dead_ = is_dead; + return &meta_data; +} + +//***************************************************************************** +// RdmaTensorRequest +//***************************************************************************** + +RdmaTensorRequest::RdmaTensorRequest( + uint32_t index, const string& key, int64 step_id, RdmaChannel* channel, + Device* dst_dev, const Rendezvous::Args recv_args, + const RdmaTensorRequest::RecvDoneCallback& done) + : index_(index), + key_(key), + step_id_(step_id), + channel_(channel), + dst_dev_(dst_dev), + recv_args_(recv_args), + meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)), + result_tensor_(nullptr), + proxy_tensor_(nullptr), + rdma_addr_(nullptr), + mr_(nullptr), + done_(done) {} + +RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); } + +void RdmaTensorRequest::Done(const Status& s) { + Tensor val = std::move(*result_tensor_); + +#ifdef RDMA_DATA_VALIDATION + // Validate checksum + // Unfortunately we can't always do a Checksum directly on the result tensor. + // If the result tensor is on GPU, then we need to copy it back to CPU. If + // we happen to be in the midst of a proxy callback, then the copying will + // get stuck. + uint64_t checksum = (proxy_tensor_ != nullptr) + ? Checksum(nullptr, nullptr, *proxy_tensor_) + : Checksum(dst_dev_, recv_args_.device_context, val); + ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA"); +#endif + + Rendezvous::Args recv_args = std::move(recv_args_); + bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_; + RecvDoneCallback done = done_; + DeallocateTensors(); + channel_->RemoveTensorRequest(index_); + done(s, Rendezvous::Args(), recv_args, val, is_dead); +} + +void RdmaTensorRequest::DeallocateTensors() { + if (result_tensor_ != nullptr) { + delete result_tensor_; + result_tensor_ = nullptr; + } + if (proxy_tensor_ != nullptr) { + delete proxy_tensor_; + proxy_tensor_ = nullptr; + } +} + +bool RdmaTensorRequest::AllocateTensors() { + result_tensor_ = + new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs), + meta_data_->data_type_, meta_data_->tensor_shape_); + + size_t tensor_size = result_tensor_->TotalBytes(); + bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype()); + if (can_memcpy) { + if (tensor_size == 0) { + return true; + } + rdma_addr_ = DMAHelper::base(result_tensor_); + mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size); +#if GOOGLE_CUDA + if (mr_ == nullptr) { + // Can't RDMA directly to result. Use a proxy. + proxy_tensor_ = + new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0), + result_tensor_->dtype(), result_tensor_->shape()); + rdma_addr_ = DMAHelper::base(proxy_tensor_); + mr_ = + RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size); + } +#endif + } else { + uint32_t proto_size = meta_data_->proto_size_; + rdma_addr_ = malloc(proto_size); + mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + } + CHECK(mr_ != nullptr) << " No memory region found for address " << rdma_addr_ + << ": " << key_; + return true; +} + +void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) { + AllocateTensors(); + bool on_host = recv_args_.alloc_attrs.on_host(); + if (dst_dev_->tensorflow_gpu_device_info() && !on_host && + (proxy_tensor_ == nullptr)) { +#if GOOGLE_CUDA + // We need to sync the memory allocation on the GPU: + StreamGPUOp(dst_dev_, recv_args_.device_context, done); +#endif + } else { + done(Status::OK()); + } +} + +void RdmaTensorRequest::Send(RdmaMessageType message_type) { + RdmaMessageBuffer* rb = channel_->tx_message_buffer_; + RdmaMessage rm; + rm.type_ = message_type; + rm.request_index_ = index_; + rm.name_size_ = key_.size(); + rm.name_ = key_; + rm.step_id_ = step_id_; + rm.remote_addr_ = (uint64_t)rdma_addr_; + if (meta_data_ != nullptr) { + rm.data_type_ = meta_data_->data_type_; + rm.tensor_shape_ = meta_data_->tensor_shape_; + rm.is_dead_ = meta_data_->is_dead_; + rm.tensor_bytes_ = meta_data_->proto_size_; + } else { + rm.data_type_ = DT_INVALID; + } + rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey; + + RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec + << ": Sending " << MessageTypeToString(message_type) << " #" + << index_ << ": " << rm.name_ << " on " << rdma_addr_ + << " (rkey: 0x" << std::hex << rm.rkey_ << ")"; + + string message = RdmaMessage::CreateMessage(rm); + rb->EnqueueItem(message); + rb->SendNextItem(); +} + +void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape, + bool is_dead, size_t proto_size) { + meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData( + key_, dtype, shape, is_dead, proto_size); + + DeallocateTensors(); + AllocateTensorsAsync( + [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); }); +} + +void RdmaTensorRequest::RecvTensorContent() { + bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_); + size_t message_size = + can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_; + RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec + << ": Received tensor content #" << index_ << ": " << key_ + << " (Size: 0x" << std::hex << message_size << ")"; + + Tensor val; + +#if GOOGLE_CUDA + if (proxy_tensor_ != nullptr) { + CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_), + (void*)DMAHelper::base(result_tensor_), + result_tensor_->TotalBytes(), false); + GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context, + dst_dev_, result_tensor_, + [this](const Status& s) { + CHECK(s.ok()) << "copy tensor to gpu sync"; + Done(s); + }); + return; + } +#endif + + if (can_memcpy) { + Done(Status::OK()); + } else { + RDMA_LOG(2) << "Decoding proto: " << key_ + << " (Size: " << meta_data_->proto_size_ << ")"; + TensorProto proto; + CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_)) + << "fail to parse proto from array"; + ibv_dereg_mr(mr_); + free(rdma_addr_); + Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs, + result_tensor_); + Done(s); + } +} + +void RdmaTensorRequest::RecvErrorStatus(const Status& status) { + if (result_tensor_ == nullptr) { + result_tensor_ = new Tensor(); + } + LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString(); + Done(status); +} + +void RdmaTensorRequest::Start() { + meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_); + if (meta_data_ != nullptr) { + AllocateTensorsAsync( + [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); }); + } else { + Send(RDMA_MESSAGE_TENSOR_REQUEST); + } } } // end namespace tensorflow diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h index fea2327d77ffff67c4b3c45835a81f790bbd1574..94203ee2b3654bffe82d203cde8780a64f63ba2a 100644 --- a/tensorflow/contrib/verbs/rdma.h +++ b/tensorflow/contrib/verbs/rdma.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_ +#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_H_ +#define TENSORFLOW_CONTRIB_VERBS_RDMA_H_ #ifdef TENSORFLOW_USE_VERBS @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" @@ -43,6 +44,11 @@ namespace tensorflow { #define SL_DEFAULT 0 #define TRAFFIC_CLASS 0 +#define RDMA_LOG_0 LOG(INFO) +#define RDMA_LOG_1 VLOG(1) +#define RDMA_LOG_2 VLOG(2) +#define RDMA_LOG(LEVEL) RDMA_LOG_##LEVEL + struct RdmaParams { uint8_t port_num; uint8_t sgid_index; @@ -67,38 +73,305 @@ struct RemoteMR { uint64_t remote_addr; uint32_t rkey; }; -enum BufferStatus { - none, - idle, - busy +enum BufferStatus { none, idle, busy }; +enum Location { local, remote }; + +enum RdmaMessageType { + RDMA_MESSAGE_META_DATA_UPDATE, + RDMA_MESSAGE_TENSOR_RE_REQUEST, + RDMA_MESSAGE_TENSOR_REQUEST, + RDMA_MESSAGE_ERROR_STATUS, +}; + +struct RdmaMessage { + RdmaMessageType type_; + uint16_t name_size_; + string name_; + int64 step_id_; + uint64_t request_index_; + union { + uint64_t remote_addr_; +#ifdef RDMA_DATA_VALIDATION + uint64_t checksum_; +#endif + }; + uint32_t rkey_; + bool is_dead_; + DataType data_type_; + TensorShape tensor_shape_; + size_t tensor_bytes_; + + // For error status: + Status status_; + + // type|name_size|name|step_id|request_index|remote_addr/checksum|rkey|... + // 1B| 2B | 512| 8B | 8B | 8B | 4B |... + // ...|is_dead|data_type|tensor_shape|tensor_bytes|error_status | + // ...| 1B | XB | XB | 8B |size - 4B, proto - XB | + static const size_t kNameCapacity = 512; + static const size_t kTypeStartIndex = 0; + static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_); + static const size_t kNameStartIndex = + kNameSizeStartIndex + sizeof(name_size_); + static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity; + static const size_t kRequestIndexStartIndex = + kStepIdStartIndex + sizeof(step_id_); + static const size_t kRemoteAddrStartIndex = + kRequestIndexStartIndex + sizeof(request_index_); + static const size_t kChecksumStartIndex = kRemoteAddrStartIndex; + static const size_t kRkeyStartIndex = + kRemoteAddrStartIndex + sizeof(remote_addr_); + static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_); + static const size_t kDataTypeStartIndex = + kIsDeadStartIndex + sizeof(is_dead_); + static const size_t kTensorShapeStartIndex = + kDataTypeStartIndex + sizeof(data_type_); + static const size_t kTensorBytesStartIndex = + kTensorShapeStartIndex + sizeof(TensorShape); + static const size_t kErrorStatusStartIndex = + kTensorBytesStartIndex + sizeof(tensor_bytes_); + static const size_t kErrorStatusMaxSize = 4096; + + static const size_t kMessageTotalBytes = kErrorStatusStartIndex; + static const size_t kRdmaMessageBufferSize = + kMessageTotalBytes + kErrorStatusMaxSize; + static string CreateMessage(const RdmaMessage& rm); + static void ParseMessage(RdmaMessage& rm, void* buffer); +}; + +// Immediate types for RDMA write +enum RdmaImmDataType { + RDMA_IMM_MAX_REQUEST_ID = 0xFFFFFFFD, + RDMA_IMM_DATA_ACK = 0xFFFFFFFE, + RDMA_IMM_DATA_MESSAGE = 0xFFFFFFFF }; -enum Location { - local, - remote + +// Write types for RDMA write-complete events +enum RdmaWriteIDType { + RDMA_WRITE_ID_ACK, + RDMA_WRITE_ID_MESSAGE, + RDMA_WRITE_ID_TENSOR_WRITE }; -enum BufferType { - ACK, - MESSAGE, - TENSOR + +// Context for RDMA write-complete events +class RdmaWriteID { + public: + RdmaWriteID(RdmaWriteIDType write_type, void* write_context) + : write_type(write_type), write_context(write_context) {} + + RdmaWriteIDType write_type; + void* write_context; }; -enum RdmaMessageType { - RDMA_MESSAGE_ACK, - RDMA_MESSAGE_BUFFER_IDLE, - RDMA_MESSAGE_BUFFER_REQUEST, - RDMA_MESSAGE_BUFFER_RESPONSE, - RDMA_MESSAGE_TENSOR_REQUEST, - RDMA_MESSAGE_TENSOR_WRITE + +// Tensor meta-data +class TensorMetaData { + public: + TensorShape tensor_shape_; + DataType data_type_; + size_t proto_size_; + bool is_dead_; + + std::ostream& print(std::ostream& out) const { + out << "Dtype = " << DataTypeString(data_type_) + << ", Shape = " << tensor_shape_.DebugString() << ", Proto size = 0x" + << std::hex << proto_size_ << ", Is dead = " << is_dead_; + return out; + } +}; + +inline std::ostream& operator<<(std::ostream& out, + const TensorMetaData& meta_data) { + return meta_data.print(out); +} + +class RdmaChannel; + +void MRDeleter(ibv_mr* mr); +using MemoryRegionPtr = std::unique_ptr; + +// RdmaMemoryMgr +// Manages the local meta-data cache, and the registered RDMA memory regions. +class RdmaMemoryMgr { + public: + static RdmaMemoryMgr& Singleton() { + static RdmaMemoryMgr instance; + return instance; + } + + // Memory regions + ibv_mr* FindMemoryRegion(void* addr, size_t length); + void InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name); + void EvictMemoryRegion(void* addr, size_t length); + + // Tensor meta-data cache + const TensorMetaData* GetTensorMetaData(const std::string& tensor_name); + const TensorMetaData* SetTensorMetaData(const std::string& tensor_name, + DataType dtype, + const TensorShape& shape, + bool is_dead, size_t proto_size); + + struct ibv_pd* pd_; + + protected: + RdmaMemoryMgr() : pd_(nullptr) {} + + static bool Comparator(const void* ptr, const MemoryRegionPtr& other) { + return ptr < reinterpret_cast(other->addr) + other->length; + } + + private: + mutex tensor_meta_data_mu_; + std::unordered_map tensors_meta_data_; + + // Managed memory regions + mutex mrs_mu_; + std::vector mrs_ GUARDED_BY(mrs_mu_); }; -class RdmaBuffer; + +// RdmaTensorRequest +// Represents a single tensor request. +class RdmaTensorRequest { + public: + typedef Rendezvous::DoneCallback RecvDoneCallback; + + // Creates a tensor request identified by index. + RdmaTensorRequest(uint32_t index, const string& key, int64 step_id, + RdmaChannel* channel, Device* dst_dev, + const Rendezvous::Args recv_args, + const RecvDoneCallback& done); + ~RdmaTensorRequest(); + + // Request unique index. + uint32_t index() { return index_; } + + // Start the tensor request sequence. + // + // 1. Allocate the result tensor (and proxy tensor if required). + // 2. Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side. + void Start(); + + // Receive tensor meta-data. + // + // 1. Update the local meta-data cache. + // 2. Reallocate the result tensor (and proxy tensor if required). + // 3. Re-send the request to the remote side. + void RecvTensorMetaData(DataType dtype, TensorShape shape, bool is_dead, + size_t proto_size); + + // Receive tensor content (RDMA write was completed). + // + // Decode proto if required and/or move to GPU if the content was not + // written to it directly (GPU direct is not avaliable). Afterwards, + // invoke Done(). + void RecvTensorContent(); + + // Receive error status (in case of a remote error). + // Invoke Done() with the status code. + void RecvErrorStatus(const Status& status); + +#ifdef RDMA_DATA_VALIDATION + // Receive tensor checksum + // + // For validation: Get and store the Tensor's expected checksum for the + // current request. Compare the result Tensor's checksum with the stored + // checksum right before invoking Done(). + void RecvTensorChecksum(uint64_t checksum) { checksum_ = checksum; } +#endif + + private: + void Done(const Status& s); + void Send(RdmaMessageType message_type); + bool AllocateTensors(); + void AllocateTensorsAsync(StatusCallback done); + void DeallocateTensors(); + + uint32_t index_; + string key_; + int64 step_id_; + RdmaChannel* channel_; + Device* dst_dev_; + Rendezvous::Args recv_args_; + const TensorMetaData* meta_data_; + Tensor* result_tensor_; + Tensor* proxy_tensor_; + void* rdma_addr_; + ibv_mr* mr_; + RecvDoneCallback done_; +#ifdef RDMA_DATA_VALIDATION + uint64_t checksum_; +#endif +}; + +// RdmaTensorResponse +// Represents a single tensor response. +class RdmaTensorResponse { + public: + // Creates a response for request message. + RdmaTensorResponse(RdmaChannel* channel, const RdmaMessage& rm) + : channel_(channel), rm_(rm) {} + + void Update(const RdmaMessage& rm) { rm_ = rm; } + + // Start the tensor response sequence. + // + // 1. Find the tensor in the local tag-match table and invoke RecvHandler. + // (Using RecvLocalAsync()). + // 2. Compare the tensor's meta-data to the meta-data in the message (taken + // from the requester's local cache). + // If meta-data changed: + // a. Clone the tensor to be sent later. + // b. Send a meta-data update message and wait for re-request. + // Else: + // a. Send the tensor's content (using direct RDMA write). + void Start(); + + // Resume the response sequence, after a re-request. + // + // 1. Send the tensor's content that was cloned earlier. + void Resume(); + + // Destroy the response's resources and remove it from the pending list. + void Destroy(); + + private: + void RecvHandler(Rendezvous::ParsedKey parsed, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, + bool is_dead); + void Clone(const Tensor& in, const TensorProto& proto, bool is_dead); + void Send(const Tensor& in, const TensorProto& proto, bool is_dead, + const Status& status); + bool TensorMetaDataChanged(const Tensor& in, bool is_dead); + Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, + Device** src_dev); + void SendMetaData(const Tensor& in, const TensorProto& proto, bool is_dead); + void SendContent(const Tensor& in, const TensorProto& proto, bool is_dead); + void SendErrorStatus(const Status& status); + + RdmaChannel* channel_; + RdmaMessage rm_; // The request message + Device* src_dev_ = nullptr; + TensorBuffer* src_buffer_ = nullptr; + void* src_addr_ = nullptr; + ibv_mr* mr_ = nullptr; + uint64_t checksum_ = 0; + bool meta_data_changed_ = false; + + // Re-item: + TensorProto* proto_ = nullptr; + Tensor* tensor_ = nullptr; + bool is_dead_ = false; +}; + +class RdmaMessageBuffer; // Class that represents the Rdma Adapter. // Responsible for creation of the completion queue, and handling // of work completions. class RdmaAdapter { friend class RdmaChannel; - friend class RdmaBuffer; - friend class RdmaAckBuffer; friend class RdmaMessageBuffer; - friend class RdmaTensorBuffer; + friend class RdmaTensorResponse; friend class RdmaMgr; friend class RdmaRemoteRendezvous; @@ -133,10 +406,10 @@ class RdmaAdapter { // Responsible for connecting queue pairs. class RdmaChannel { friend class RdmaAdapter; - friend class RdmaBuffer; - friend class RdmaAckBuffer; friend class RdmaMessageBuffer; friend class RdmaTensorBuffer; + friend class RdmaTensorRequest; + friend class RdmaTensorResponse; friend class RdmaMgr; friend class RdmaRemoteRendezvous; @@ -146,22 +419,28 @@ class RdmaChannel { ~RdmaChannel(); inline const RdmaAddress& self() { return self_; } RdmaAddress address() const; - inline const std::vector& message_buffers() const { + inline const std::vector& message_buffers() const { return message_buffers_; } void Connect(const RdmaAddress& remoteAddr); void Connect(); void Recv(); - RdmaBuffer* FindBuffer(const uint32_t index); - RdmaBuffer* FindBuffer(const string& name); - RdmaBuffer* FindOrCreateBuffer(const string& name, - BufferType buffer_type = TENSOR); - uint32_t LookupBufferIndex(const string& buffer_name); void SetRemoteAddress(const RdmaAddress& ra, bool override); - void InsertRecvCallback(const string& key, std::function recv_done); - void RemoveRecvCallback(const string& key); - void RunRecvCallback(const string& key); - static const int kNumMessageBuffers = 4; + + // Requests: + RdmaTensorRequest* InsertTensorRequest( + const string& key, int64 step_id, Device* dst_dev, + const Rendezvous::Args recv_args, + const RdmaTensorRequest::RecvDoneCallback& done); + void RemoveTensorRequest(uint32_t request_index); + RdmaTensorRequest* GetTensorRequest(uint32_t request_index); + + // Responses: + RdmaTensorResponse* AddTensorResponse(const RdmaMessage& rm); + RdmaTensorResponse* UpdateTensorResponse(const RdmaMessage& rm); + void RemoveTensorResponse(uint32_t request_index); + + static const int kNumMessageBuffers = 2; static const int kPingRecvWrid = 0; private: @@ -179,36 +458,31 @@ class RdmaChannel { string remote_name_; ibv_qp* qp_; mutex mu_; - bool connected_ GUARDED_BY(bt_mu_) = false; - RdmaAddress remote_ GUARDED_BY(bt_mu_); - bool remote_set_ GUARDED_BY(bt_mu_) = false; + bool connected_ GUARDED_BY(mu_) = false; + RdmaAddress remote_ GUARDED_BY(mu_); + bool remote_set_ GUARDED_BY(mu_) = false; mutex ct_mu_; - typedef std::unordered_map > CallbackTable; - CallbackTable callback_table_ GUARDED_BY(ct_mu_); - mutex bt_mu_; - typedef std::unordered_map BufferTable; - BufferTable buffer_table_ GUARDED_BY(bt_mu_); - typedef std::unordered_map BufferIndexNameTable; - BufferIndexNameTable buffer_index_name_table_ GUARDED_BY(bt_mu_); - typedef std::unordered_map BufferNameIndexTable; - BufferNameIndexTable buffer_name_index_table_ GUARDED_BY(bt_mu_); - RdmaBuffer* tx_message_buffer_; - RdmaBuffer* rx_message_buffer_; - RdmaBuffer* tx_ack_buffer_; - RdmaBuffer* rx_ack_buffer_; - std::vector message_buffers_; + typedef std::unordered_map RequestTable; + RequestTable request_table_ GUARDED_BY(ct_mu_); + uint32_t request_serial_ GUARDED_BY(ct_mu_); + mutex responses_mu_; + typedef std::unordered_map ResponsesTable; + ResponsesTable responses_table_ GUARDED_BY(responses_mu_); + RdmaMessageBuffer* tx_message_buffer_; + RdmaMessageBuffer* rx_message_buffer_; + std::vector message_buffers_; }; -// Class that represents a buffer for Rdma writes and reads. -class RdmaBuffer { +// Class that represents a buffer for Rdma message sending. +class RdmaMessageBuffer { friend class RdmaChannel; friend class RdmaAdapter; friend class RdmaMgr; friend class RdmaRemoteRendezvous; public: - explicit RdmaBuffer(RdmaChannel* channel, string name); - virtual ~RdmaBuffer(); + explicit RdmaMessageBuffer(RdmaChannel* channel, string name); + ~RdmaMessageBuffer(); inline void* buffer() const { return buffer_; } inline ibv_mr* self() const { return self_; } @@ -223,13 +497,15 @@ class RdmaBuffer { } void FreeBuffer(); void EnqueueItem(string Item); - virtual void SendNextItem() {}; + void SendNextItem(); void CreateCPUBuffer(size_t size, bool lock = true); void SetRemoteMR(RemoteMR rmi, bool override); - uint32_t LookupBufferIndex(const string& buffer_name) { - return const_cast(channel_)->LookupBufferIndex(buffer_name); - } void Write(uint32_t imm_data, size_t buffer_size); + static void Write(const RdmaChannel* channel, uint32_t imm_data, + size_t buffer_size, uint64_t src_addr, uint32_t lkey, + uint64_t remote_addr, uint32_t rkey, + RdmaWriteIDType write_type, void* write_context); + static void SendAck(const RdmaChannel* channel); protected: const RdmaChannel* channel_; @@ -245,126 +521,7 @@ class RdmaBuffer { BufferStatus remote_status_ GUARDED_BY(mu_) = none; }; -class RdmaAckBuffer : public RdmaBuffer { - public: - explicit RdmaAckBuffer(RdmaChannel* channel, string name); - virtual ~RdmaAckBuffer() override {} - void SendNextItem() override; -}; - -class RdmaMessageBuffer : public RdmaBuffer { - friend class RdmaChannel; - friend class RdmaAapater; - - public: - explicit RdmaMessageBuffer(RdmaChannel* channel, string name); - virtual ~RdmaMessageBuffer() override {} - void SendNextItem() override; -}; - -class RdmaTensorBuffer : public RdmaBuffer { - public: - explicit RdmaTensorBuffer(RdmaChannel* channel, string name); - virtual ~RdmaTensorBuffer() override; - void SendNextItem() override; - void PostCopyOperations(bool can_memcpy, size_t buffer_size, - size_t tensor_bytes, const string& key, - const Tensor& in, int64 step_id, bool is_dead, - const string& key_with_step_id, const Tensor* copy, - const TensorProto* proto, const StringPiece* copy_buf, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args); - - void ReSendNextItem(); - - private: - Rendezvous::DoneCallback getRecvTensorCallback( - const string& key_with_step_id, const string& key, int64 step_id, - const Rendezvous::ParsedKey& parsed); - - struct ReItem { - Rendezvous::Args send_args; - Rendezvous::Args recv_args; - Tensor in; - bool is_dead; - - ReItem(const Rendezvous::Args& send_args_, - const Rendezvous::Args& recv_args_, const Tensor& in_, bool is_dead_) - : send_args(send_args_), - recv_args(recv_args_), - in(in_), - is_dead(is_dead_) { - if (send_args.device_context) { - send_args.device_context->Ref(); - } - if (recv_args.device_context) { - recv_args.device_context->Ref(); - } - } - - ~ReItem() { - if (send_args.device_context) { - send_args.device_context->Unref(); - } - if (recv_args.device_context) { - recv_args.device_context->Unref(); - } - } - }; - typedef std::map Table; - typedef Table::iterator Itable; - - std::queue requeue GUARDED_BY(mu_); - Table retable GUARDED_BY(mu_); -}; - -struct RdmaMessage { - RdmaMessageType type_; - uint16_t name_size_; - string name_; - int64 step_id_; - uint64_t buffer_size_; - uint64_t remote_addr_; - uint32_t rkey_; - bool is_dead_; - DataType data_type_; - TensorShape tensor_shape_; - size_t tensor_bytes_; - - // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|... - // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |... - // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer - // ...| XB | XB | 8B |... - // - static const size_t kNameCapacity = 512; - static const size_t kTypeStartIndex = 0; - static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_); - static const size_t kNameStartIndex = - kNameSizeStartIndex + sizeof(name_size_); - static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity; - static const size_t kBufferSizeStartIndex = - kStepIdStartIndex + sizeof(step_id_); - static const size_t kRemoteAddrStartIndex = - kBufferSizeStartIndex + sizeof(buffer_size_); - static const size_t kRkeyStartIndex = - kRemoteAddrStartIndex + sizeof(remote_addr_); - static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_); - static const size_t kDataTypeStartIndex = - kIsDeadStartIndex + sizeof(is_dead_); - static const size_t kTensorShapeStartIndex = - kDataTypeStartIndex + sizeof(data_type_); - static const size_t kTensorBytesStartIndex = - kTensorShapeStartIndex + sizeof(TensorShape); - static const size_t kTensorBufferStartIndex = - kTensorBytesStartIndex + sizeof(tensor_bytes_); - static const size_t kMessageTotalBytes = kTensorBufferStartIndex; - static const size_t kRdmaMessageBufferSize = kMessageTotalBytes; - static const size_t kRdmaAckBufferSize = kMessageTotalBytes; - static string CreateMessage(const RdmaMessage& rm); - static void ParseMessage(RdmaMessage& rm, void* buffer); -}; - } // namespace tensorflow #endif // TENSORFLOW_USE_VERBS -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_ +#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_H_ diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 9cb307bcfa06cfdf5ecb9b4faa1d3710e5701080..369bd986df5313955bc22d6e5c6d38815908ada3 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -16,11 +16,16 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS #include "tensorflow/contrib/verbs/rdma_mgr.h" +#include #include #include "tensorflow/contrib/verbs/grpc_verbs_client.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" +#include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -53,7 +58,7 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env, void RdmaMgr::SetupChannels() { for (const auto& p : channel_table_) { string worker_name = p.first; - LOG(INFO) << "connecting to remote node " << worker_name; + RDMA_LOG(2) << "Connecting to remote node " << worker_name; RdmaChannel* rc = p.second; GetRemoteAddressRequest req; GetRemoteAddressResponse resp; @@ -78,39 +83,49 @@ void RdmaMgr::SetupChannels() { mr->set_rkey(rc->message_buffers_[i]->self_->rkey); } // synchronous call - Status s = client->GetRemoteAddress(&req, &resp); - // save obtained remote addresses - // connect to the remote channel - if (s.ok()) { - CHECK(worker_name.compare(resp.host_name()) == 0); - RdmaAddress ra; - ra.lid = resp.channel().lid(); - ra.qpn = resp.channel().qpn(); - ra.psn = resp.channel().psn(); - ra.snp = resp.channel().snp(); - ra.iid = resp.channel().iid(); - rc->SetRemoteAddress(ra, false); - rc->Connect(); - int i = 0; - int idx[] = {1, 0, 3, 2}; - for (const auto& mr : resp.mr()) { - // the connections are crossed, i.e. - // local tx_message_buffer <---> remote rx_message_buffer_ - // local rx_message_buffer <---> remote tx_message_buffer_ - // local tx_ack_buffer <---> remote rx_ack_buffer_ - // local rx_ack_buffer <---> remote tx_ack_buffer_ - // hence idx[] = {1, 0, 3, 2}. - RdmaBuffer* rb = rc->message_buffers_[idx[i]]; - RemoteMR rmr; - rmr.remote_addr = mr.remote_addr(); - rmr.rkey = mr.rkey(); - rb->SetRemoteMR(rmr, false); - i++; + Status s; + int attempts = 0; + static const int max_num_attempts = 5; + do { + s = client->GetRemoteAddress(&req, &resp); + // save obtained remote addresses + // connect to the remote channel + if (s.ok()) { + CHECK(worker_name.compare(resp.host_name()) == 0); + RdmaAddress ra; + ra.lid = resp.channel().lid(); + ra.qpn = resp.channel().qpn(); + ra.psn = resp.channel().psn(); + ra.snp = resp.channel().snp(); + ra.iid = resp.channel().iid(); + rc->SetRemoteAddress(ra, false); + rc->Connect(); + int i = 0; + int idx[] = {1, 0}; + for (const auto& mr : resp.mr()) { + // the connections are crossed, i.e. + // local tx_message_buffer <---> remote rx_message_buffer_ + // local rx_message_buffer <---> remote tx_message_buffer_ + // hence idx[] = {1, 0}. + RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]]; + RemoteMR rmr; + rmr.remote_addr = mr.remote_addr(); + rmr.rkey = mr.rkey(); + rb->SetRemoteMR(rmr, false); + i++; + } + CHECK(i == RdmaChannel::kNumMessageBuffers); + } else { + LOG(ERROR) << "Connecting to " << worker_name << ": Got " + << s.error_message() << ". Retrying (" << (attempts + 1) + << "/" << max_num_attempts << ")..."; + if (++attempts == max_num_attempts) { + break; + } + worker_env_->env->SleepForMicroseconds(2000000); } - CHECK(i == RdmaChannel::kNumMessageBuffers); - } else { - LOG(ERROR) << s.error_message(); - } + } while (!s.ok()); + RDMA_LOG(0) << "Connected to remote node " << worker_name; delete client; } } @@ -144,19 +159,17 @@ bool RdmaMgr::ConnectivityCheck() { ibv_wc_status s = rdma_adapter_->wc_[i].status; // recv complete if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) { - CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str( - rdma_adapter_->wc_[i].status) - << "(" << rdma_adapter_->wc_[i].status - << ") for PING_RECV_WRID"; + CHECK(s == IBV_WC_SUCCESS) + << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "(" + << rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID"; ++rcnt; // send complete } else { RdmaChannel* rc = reinterpret_cast(rdma_adapter_->wc_[i].wr_id); - CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str( - rdma_adapter_->wc_[i].status) - << "(" << rdma_adapter_->wc_[i].status - << ") to " << rc->remote_name_; + CHECK(s == IBV_WC_SUCCESS) + << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "(" + << rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_; ++scnt; } } // for @@ -183,6 +196,139 @@ RdmaChannel* RdmaMgr::FindChannel(const string& name) { return iter->second; } +bool IsGDRAvailable() { +#if defined(__APPLE__) + return false; +#elif defined(PLATFORM_WINDOWS) + return false; +#else + std::ifstream ifs("/proc/modules"); + string line; + while (std::getline(ifs, line)) { + auto sep = line.find(' '); + CHECK_NE(sep, std::string::npos); + if (line.substr(0, sep) == "nv_peer_mem") { + return true; + } + } + return false; +#endif +} + +int TryToReadNumaNode(ibv_device* device) { +#if defined(__APPLE__) + LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0"; + return 0; +#elif defined(PLATFORM_WINDOWS) + // Windows support for NUMA is not currently implemented. Return node 0. + return 0; +#else + VLOG(2) << "Trying to read NUMA node for device: " << device->name; + static const int kUnknownNumaNode = -1; + + auto filename = string(device->ibdev_path) + "/device/numa_node"; + + std::ifstream ifs(filename.c_str()); + string content; + CHECK(std::getline(ifs, content)); + + int32 value; + if (strings::safe_strto32(content, &value)) { + if (value < 0) { + LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" + << value + << "), but there must be at least one NUMA node" + ", so returning NUMA node zero"; + return 0; + } + LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; + return value; + } + return kUnknownNumaNode; +#endif +} + +void MRDeleter(ibv_mr* mr) { + if (mr) { + ibv_dereg_mr(mr); + } +} + +// TODO(byronyi): remove this class duplicated from the one in +// common/runtime/gpu/pool_allocator.h when it is available in common_runtime +class BasicCPUAllocator : public SubAllocator { + public: + ~BasicCPUAllocator() override {} + + void* Alloc(size_t alignment, size_t num_bytes) override { + return port::AlignedMalloc(num_bytes, alignment); + } + void Free(void* ptr, size_t) override { port::AlignedFree(ptr); } +}; + +// TODO(byronyi): remove this class and its registration when the default +// cpu_allocator() returns visitable allocator +class BFCRdmaAllocator : public BFCAllocator { + public: + BFCRdmaAllocator() + : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") { + } +}; + +REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator); + +void RdmaMgr::InitAllocators() { + RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_; + + Allocator* allocators[] = { +#if GOOGLE_CUDA + ProcessState::singleton()->GetCUDAHostAllocator(0), + ProcessState::singleton()->GetCPUAllocator(0), +#endif // GOOGLE_CUDA + cpu_allocator(), + }; + + using namespace std::placeholders; + + std::set instrumented_; + + // Host memory allocators + for (Allocator* allocator : allocators) { + VisitableAllocator::Visitor alloc_visitor = + std::bind(&RdmaMemoryMgr::InsertMemoryRegion, + &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name()); + VisitableAllocator::Visitor free_visitor = std::bind( + &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2); + + auto* visitable_allocator = dynamic_cast(allocator); + CHECK(visitable_allocator) + << "is not visitable for instrumentation" << allocator->Name(); + // Make sure we don't instrument the same allocator twice + if (instrumented_.find(allocator) == std::end(instrumented_)) { + visitable_allocator->AddAllocVisitor(alloc_visitor); + visitable_allocator->AddFreeVisitor(free_visitor); + instrumented_.insert(allocator); + LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name(); + } + } + +#if GOOGLE_CUDA + if (IsGDRAvailable()) { + // Note we don't free allocated GPU memory so there is no free visitor + int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1; + + char buf[8]; + sprintf(buf, "gpu"); + VisitableAllocator::Visitor cuda_alloc_visitor = + std::bind(&RdmaMemoryMgr::InsertMemoryRegion, + &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf)); + + ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); + LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; + } +#endif // GOOGLE_CUDA +} + } // end namespace tensorflow #endif diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h index e711e604788b12ff0c1a0977a90db21f9f8fa50e..9fffc335bbe2bf47a626736f6d3073f52b32a9c2 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.h +++ b/tensorflow/contrib/verbs/rdma_mgr.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ +#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ +#define TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ #ifdef TENSORFLOW_USE_VERBS @@ -38,6 +38,7 @@ class RdmaMgr { RdmaChannel* FindChannel(const string& key); void SetupChannels(); bool ConnectivityCheck(); + void InitAllocators(); const string& local_worker() { return local_worker_; } private: @@ -54,4 +55,4 @@ class RdmaMgr { } // namespace tensorflow #endif // TENSORFLOW_USE_VERBS -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ +#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index 74f6681af3c29f370d6cdb37d64e10a30cbb7b84..ad3dce17844c5a43237372fb7fe074416e8b7117 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -21,10 +21,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -36,11 +32,6 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous { RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr) : BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {} - void RecvPostCopyOps(const string& key, const string& key_with_step_id, - const Rendezvous::Args& recv_args, - const DoneCallback& done, const RdmaMessage& rm, - RdmaChannel* rc, Tensor& val, const Status& s); - protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, @@ -74,101 +65,18 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( RdmaChannel* rc = rdma_mgr_->FindChannel(src_name); string key(std::move(parsed.FullKey().ToString())); string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_); - // insert callback - rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc, - recv_args, parsed, done]() { - Status src_s, dst_s, s; - Device* src_dev, *dst_dev; - src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); - dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); - if (!src_s.ok() || !dst_s.ok()) { - s = src_s.ok() ? dst_s : src_s; - LOG(ERROR) << "s is not ok, error code " << s.error_message(); - done(s, Args(), recv_args, Tensor(), true); - return; - } - RdmaBuffer* rb = rc->FindBuffer(key); - RdmaMessage rm; - CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes); - RdmaMessage::ParseMessage(rm, rb->buffer_); - CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE); - Tensor val; - if (!rm.is_dead_) { - void* input = static_cast(rb->buffer_) + - RdmaMessage::kTensorBufferStartIndex; - bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_); - if (can_memcpy) { - if (dst_dev->tensorflow_gpu_device_info() && - (!recv_args.alloc_attrs.on_host())) { -#if GOOGLE_CUDA - CHECK(recv_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); - Tensor copy(alloc, rm.data_type_, rm.tensor_shape_); - memcpy(DMAHelper::base(©), input, rm.tensor_bytes_); - - Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs); - Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_); - - GPUUtil::CopyCPUTensorToGPU( - ©, recv_args.device_context, dst_dev, &gpu_copy, - [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc]( - const Status& s) { - CHECK(s.ok()) << "copy tensor to gpu sync"; - Tensor val; - val = std::move(gpu_copy); - RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, - val, s); - }); -#endif // GOOGLE_CUDA - return; - } else { - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs); - Tensor copy(alloc, rm.data_type_, rm.tensor_shape_); - memcpy(DMAHelper::base(©), input, rm.tensor_bytes_); - val = std::move(copy); - } - } else { - TensorProto proto; - CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <= - rb->size_); - CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_)) - << "fail to parse proto from array"; - s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val); - } - } - RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s); - }); - // append key to message queue - RdmaBuffer* rb = rc->tx_message_buffer_; - RdmaMessage rm; - rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST; - rm.name_size_ = key.size(); - rm.name_ = key; - rm.step_id_ = step_id_; - string message = RdmaMessage::CreateMessage(rm); - rb->EnqueueItem(message); - rb->SendNextItem(); -} -void RdmaRemoteRendezvous::RecvPostCopyOps( - const string& key, const string& key_with_step_id, - const Rendezvous::Args& recv_args, const DoneCallback& done, - const RdmaMessage& rm, RdmaChannel* rc, Tensor& val, const Status& s) { - rc->RemoveRecvCallback(key_with_step_id); - RdmaMessage br; - br.type_ = RDMA_MESSAGE_BUFFER_IDLE; - br.name_size_ = key.size(); - br.name_ = key; - string message = RdmaMessage::CreateMessage(br); - RdmaBuffer* tb = rc->tx_message_buffer_; - tb->EnqueueItem(message); - tb->SendNextItem(); - done(s, Args(), recv_args, val, rm.is_dead_); + Device* dst_dev; + s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); + CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); + if (!s.ok()) { + done(s, Args(), recv_args, Tensor(), true); + return; + } + + RdmaTensorRequest* request = + rc->InsertTensorRequest(key, step_id_, dst_dev, recv_args, done); + request->Start(); } RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env) diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h index 2dedd6c48f96a6ecf2b69c757f525ac1bfd6f2d0..c0d6f59c4842e28e37b2a3b45e955f8d92712dd7 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ +#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ +#define TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ #ifdef TENSORFLOW_USE_VERBS @@ -60,4 +60,4 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr { } // end namespace tensorflow #endif // TENSORFLOW_USE_VERBS -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ +#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index a606ef75a42069b3c32eb13a69e981a5c4c8f83c..47ed83f521c5e6165c906ea557e74faf27df2112 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -104,6 +104,7 @@ Status VerbsServer::Start() { [this] { verbs_service_->HandleRPCsLoop(); })); rdma_mgr_->SetupChannels(); CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!"; + rdma_mgr_->InitAllocators(); verbs_state_ = CONNECTED; } } diff --git a/tensorflow/contrib/verbs/verbs_server_lib.h b/tensorflow/contrib/verbs/verbs_server_lib.h index 855380129f21bd8162cdf28a4d88c098db7ddc55..54ce8c1d47737f4da742925f99e3d1cd73160ffb 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.h +++ b/tensorflow/contrib/verbs/verbs_server_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ +#ifndef TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ +#define TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ #ifdef TENSORFLOW_USE_VERBS @@ -63,4 +63,4 @@ class VerbsServer : public GrpcServer { } // namespace tensorflow #endif // TENSORFLOW_USE_VERBS -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ +#endif // TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ diff --git a/tensorflow/contrib/verbs/verbs_service.proto b/tensorflow/contrib/verbs/verbs_service.proto index 0df1fed4b9de81d7d99be3de9fba4be8b88ad404..abdae1d84f74b076bb5f457d0cf6f74bf07d75b4 100644 --- a/tensorflow/contrib/verbs/verbs_service.proto +++ b/tensorflow/contrib/verbs/verbs_service.proto @@ -50,6 +50,12 @@ message GetRemoteAddressResponse { repeated MemoryRegion mr = 3; } +message ErrorStatusProto { + int32 error_code = 1; + string error_message = 2; + string error_details = 3; +} + //////////////////////////////////////////////////////////////////////////////// // // VerbsService diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies.png b/tensorflow/contrib/verbs/verbs_with_0_copies.png new file mode 100644 index 0000000000000000000000000000000000000000..0641e2fd50da3738e3b8113f4324156063bf52a2 Binary files /dev/null and b/tensorflow/contrib/verbs/verbs_with_0_copies.png differ diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies.xml b/tensorflow/contrib/verbs/verbs_with_0_copies.xml new file mode 100644 index 0000000000000000000000000000000000000000..16130a961ba5185c415463ccb9636d010fe05e68 --- /dev/null +++ b/tensorflow/contrib/verbs/verbs_with_0_copies.xml @@ -0,0 +1 @@ +7Vxtc9o4EP41zKQfmsGW3/hIgPQ60/RyIZ1rPzHClsFXY1FZEOivP8mW8ZsAB2yXtHQ6jb2SJXl3n0e7K6cdMFhsPhC4nD9gB/kdtetsOmDYUVVFUw32g0u2scRUu7FgRjxHdEoFY+8nEsKk28pzUJjrSDH2qbfMC20cBMimORkkBL/ku7nYz8+6hDNUEoxt6Jel/3oOnQup0u2mDX8hbzYXU1u6aJhC+/uM4FUg5uuowI3+xM0LmIwl+odz6OCXjAiMOmBAMKbx1WIzQD7XbaK2+Ln7Pa27dRMU0CoP6CB+Yg39FUqWHC2MbhNlRK+D+APdDrh7mXsUjZfQ5q0vzPxMNqcLn90p7NL1fH+AfUzYfYAD1ulOzIAIRZu9y1R2L8+cCuEFomTLumx2mo8fEf5kiduX1DhWIptn7GIkQigcYrYbOlUKuxB6kevIkqjI8NkMd463zqnK+LHihrtjL0rfQ9+bBR3QZz185NK0lV3NxM9olHAJg0Q2ppDQm3dJE1tatjUjjqbOS+tfTSKbEskK2lqYcsua+r6PbUgRJwIUhJiEt03OqfI5xyhwbp5Hn8d/P02eRv98GY2f37VhAam2c7PVBVDGTg5Elmtzu1OCv6NMi2FbaOru5iuBVQLp/fgFefwqRhnAalcC4B3jngPghGxrR/ATstfPkTs+IAqHkMKbC/GQ2oD3ZenEsGNah+/ZNeTbLrTnqHkAPiHYLuxlHAjKVCBjg8q0+XsBGahtAlkxjkcryGGRnLjFhM7xDAfQH6XSu7yWMxr9D1G6FcEoXFHMROkInzBein579RjiFbGTEFIsjW3oM5R002IZX+NBbRPkQ+qt89HoWZpTG6LAiCQGBMUgJShc4iBshRvs9SfGDX4/3AZ2s7QbUcAAL7dRGsKvH79wyCPisWd/8hf/6EZv/2PlEeTsffs3B3dT+aX7tv4r0M20RbZfszff+GC3Or/dePRr7u6bmKgaJ2hlTkhS5fo4QTz6iD22lJ0pe728KU2jYKF4UeKp1Eh9QuA2023JO4QH5jELO4RZyECP9E9SttRH4hWkHrPTSTXm00rMd++RkD57Cw7cjjngf1UDLjjggmm4LPJGjN4kwhvMYTBDDmMccF8V53O8mK7C4xh3GHvY1MOMjYbMbzjCLgP3LW/zvePbfDiHS37p+mjT5xWfCKyOuBzaPgxDzz5Ioa5lI9vOIr5bRnxJzVNL1/TKiLckQYBaEfAZXesSVSeyM3kBFCI6tcgL8euUeKE0kNbt3jK/MUxLUSxD10wjP65SjW9OgHjiHm35y5k+Id0FuhflFIbSu1WtHlAVy1KApqs5U2pFU1Z1kcPDgoo70ikeUi5zfkNhyUl4OJh3gbypRUFTUuMUMeTQZoZHTH7Hudbj4aloWHiOE8Unsi0gH7M0wd+gzN+axH3UOqK28ob7Gf++qraK8U6bqpblw00VQqJM75GFyfI0r61yMM/+5MFadgVPwwc+xAvxorz0Jq7SdaITI8qM/e61S3/zqZsh8cvmQjigH9+S28zlRMK2i+2q7tWqKdmrW8rYrKIFtWr74/6M7Zwd1CwZdFcaztDBm4eJ1mqFA1Q4fn0jmU6yn+WQYlZESjtRrVpIdTTzxDi2OJBe9IWaaimleZR6ayOgQj29EfeTlNbOdT+j7H7gstzvcPZjgkaSqtIXEPUlVaC8JdR9rDqIo7Xf6VRVUlqMI2uCN9soQOXnDPcOyh4veJWOF0oDyyLj6PTkY7BmYGMXDs+p+IGu7/NPl45Fxa9S0ZEz1aHkdJf7aeBE77rAayReGoX0tEzjzQfxxePWdoP4JN6UAHyuVIKAyNFlCEeMSEnGUHzEPXY6vVbAXMX2ghkT6Ondc5QevFf3GZ05HnH9aHebe46DgibKBoqp5yzbKxt299Fb1rDFHOAku6pN2ZV/J/EnW9U0jlq115RRZamEYO0iL7o4CiDsHWmldgSu2+1y8ihBdvjQnzyMxuP+h9Ek/1Vcxt7xyCUWnjbgBRdWBwRWnqoVyTeq0uiyjkKgVq65Nmf8h9FzfzLss3+eRuPHvz+PR1cHkDgA0Np0AFm9rXH0XwnggP21Vglg/0lALfYv1s+vFpdY3NDbtHhT2XfNSXUi+LhYxP2ruCF3Qv47M8VRBQO9yvsOJYiyVLLm9773kO+E+VfPLpBulyzPHaSp7sRjXrqJRQFciMaQouWE2V70XGCKJtBxiBB8R9v4in+mPYk/0z6SGZ9w6nUMuTwvNqaGbnTKNUDXVaMa4KVh2MhjWJV86QRkGbZeB4ab/tWihjAsg1m31PvPBbpUPweBfoXtebAFkmCrOdjKvk+8wvYPhO3r9ucrsk9Att7mhpyMcUV2ceqC818+viueVF1xtwd3hqR2XRfu2G36XxzEJ8/p/yMBRv8D \ No newline at end of file diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.jpg b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8bc69b889d68f0a6a8faa64f08eab22637646842 Binary files /dev/null and b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.jpg differ diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.xml b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.xml new file mode 100644 index 0000000000000000000000000000000000000000..484e7c78ae863b543dee360969e2ef1c8ae6114d --- /dev/null +++ b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.xml @@ -0,0 +1 @@ +7Vxbc5s4FP41nuk+pMMd8ujYTrYzTZqNk9nmyaOAbNhgRIWc2P31K4G4yzZ1AHtaZzJjOBLS8TnnOzeRDNTRcn2DQejeIgf6A0Vy1gN1PFAUWVMM+sEom4RiGpcJYYE9h0/KCVPvJ+REiVNXngOj0kSCkE+8sEy0URBAm5RoAGP0Xp42R3551xAsYI0wtYFfp/7rOcTlVFmS8oG/obdw+daWzgdegP26wGgV8P0GijqPf5LhJUjX4vMjFzjovUBSJwN1hBEiydVyPYI+k20qtuS56y2jGd8YBqTJA1bywBvwVzDl2PDpo1eO98b4IxsuE+PHijF1ReCaXADfWwQDdUhn+HBO8lF6teCf8SpRCIKUNiUAk09/pUOUqeJogRxvXaa2z01Ke8ECDnpkDCxDehG8ROzjgs4c+j6yAYGPMIgQjkoC64WBKQycT4+Tu+m3h9nD5J+nyfSxax62a6K0m1LaRYlxBpklS3T43fUInIbAZqPv1C9RmkuWPr2Ts6ffIKZMbQWLnEGQujaIlpDgDZ3CH7A4aLlTkw1+/567CCX1EG7BO2RuA3C3tMiWzqFJLzg6xUhNPUbrUH2A9ltia7eQgDEgoHOTa6juNnZjBj2GoE9M1UHc/X4xZq+erq8nDLPT+29308nWTU8LRqrSJ4xkQwCjikCgQ5MBfoswcdECBcCf5NSrssgK4vkPErLh+QxYEURJ+QpfEQpLYmQb7RYi5QutsJ3O4rzSQLqA6TRNLGwMfUC8t/L6H5Kc0pEDivHiON/wU+hQyDzAKERBBLsHDfN8XylM/WG0Cezu9/syj/XyY+VhZjvDPgP7gKGsJRL7LiMUbm7unx7R6F6UXT2dUp7XqSCmEHuUsZ/wqN/45HMnQztq8qQfw8lT0eDN9+LNM1vss85u1x75Xrp75hsdGBq0emhIqvAPhAb+6D3y6M6ZKi+VsipNo6KhhAf+VK6kIcZgU5gWsglR831Us1LL7plvWLvnW9rO+fQi4Ti3sEyGzQKmVguY1x6OyKO3hMyZmCP2W/MyBbeQYDdNy0cuCBbQoX5GvW6KchctX1bRfoQ7NCTZxEPU2YypWTFEdoH6nnO9y/25XuSCkF3Ofbgess5RDFWHX45tH0SRZ5eFNfd8f4R8hOMl0gZPAe9SHe+HodoS5HuKWOIFieoCgaa0D2K/mrwrVewn3NewX1tI1aXPpiwZpiXLlqFrplFeV27mUw6AZWoEfVlFi/5cGhxR9bpx+VmxLlVFtixZ1XSlpDCtqrCmhrB7WbVhbDnEDtSaHTzDqGYKLA8rKzoiGL3CVNUBCmBF+5zEk7exTfUMKf2KuVKPlRt8YOk5TpxpiJxzOfvowherdV+sCcxHaSP/qofCO/T7itqSjihqUYOjq+KKFUD3KBSW7H2VQBfbUqgiAw/jW7bCO6bap5+fkr7cID5BIlSrv8r4iVVThsDAusurVH1/BO2zvOI1VJZwHRx0FVMQdLsposyqBrVmgW57EfWRUGjWFLqlIXdidq/12kVQ6xlDTSKnXU+kAaZk4KZY5P1klXKloNDMA/PI6kJ6VeMtdSVq+8jtdg3UBgcUnRiZoEl1oJEZdSNTj2pku2sMU+2kdEnbSR2ULmrdX7d9FjxK8qLf6ShY0Fo78FSmvyOWETtiubl/aqSHftgaw0h45LGbq/hJWqzte+TEEm3rmHl2muwIYO7KjYDA62ERziF1p7ggdbbiFqEfXpfTUsr2ggUl6PndY5zBXyjbNIgoY3M/jmQuLdth0E2JXlLsZUO9VrP0g9QqakC2olb2FsifrNRqedCrVkXFAQ9vVSc3R3EaYWfpWK5ImphJEuOxBtnx7XB2O5lOhzeTWfntvILCk5VrLvWlAzM4sZ5b1mNLD5gtgfJFOWYbTTet3t/sTvnZa15n5W9Tvqr1qXxRO6xz5Sfv+J21L9C+1iv0t/fbW9F+loLHK1T71tloVe1na8hydr1Pa+iqMm+54E4JX5bLZH4TE2UGyv6Spboq902/ZH27akDRAUzL3/vag74Tlb96kUGyCeFAGfHOAIzIzKNWuk5IAVjywYjAcEZ1z2cuEYEz4DiYE17hJrmidgRmDiBgb/F7wOHTPuRSzTnGi6EbA1EXULHtE8SwXMawInhvSBVl8nobGO76j6I6wrAIZlJt9p8LdKF8dgL9DNuPwVYVJGLdwVb0tt8Ztn8gbD8Un7e+kXvG/i9hX+8zZKdrnLFf3boCj9P3AA2PAs+424I7Q9D0bgt39Db/1wTJuXX+/x/Uyf8= \ No newline at end of file diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 579174efa31bf62580feff3e46ae2290826b2add..94973a0e520e494ce2ccc947a803e10681ff5e21 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -136,6 +136,8 @@ load( "tf_nano_proto_library", "tf_protos_all", "tf_protos_all_impl", + "tf_protos_grappler", + "tf_protos_grappler_impl", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -431,6 +433,7 @@ tf_cuda_library( "framework/cancellation.h", "framework/common_shape_fns.h", "framework/control_flow.h", # TODO(josh11b): Make internal? + "framework/dataset.h", "framework/device_base.h", "framework/function.h", "framework/graph_def_util.h", @@ -592,6 +595,7 @@ cc_library( tf_gen_op_libs( is_external = False, op_lib_names = [ + "batch_ops", "bitwise_ops", "candidate_sampling_ops", "checkpoint_ops", @@ -673,6 +677,7 @@ cc_library( deps = [ ":array_ops_op_lib", ":audio_ops_op_lib", + ":batch_ops_op_lib", ":bitwise_ops_op_lib", ":candidate_sampling_ops_op_lib", ":checkpoint_ops_op_lib", @@ -808,6 +813,7 @@ cc_library( deps = [ "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", + "//tensorflow/core/kernels:batch_kernels", "//tensorflow/core/kernels:bincount_op", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:checkpoint_ops", @@ -1068,8 +1074,8 @@ cc_library( ":protos_all_cc_impl", "//third_party/eigen3", "//third_party/fft2d:fft2d_headers", - "@fft2d//:fft2d", - "@gemmlowp//:gemmlowp", + "@fft2d", + "@gemmlowp", "@protobuf_archive//:protobuf", ], alwayslink = 1, @@ -1318,6 +1324,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/function_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/function.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/graph_pyclif", proto_lib = ":protos_all_cc", @@ -1529,7 +1542,7 @@ cc_library( "@snappy", "@zlib_archive//:zlib", "@protobuf_archive//:protobuf", - ] + tf_protos_all_impl(), + ] + tf_protos_all_impl() + tf_protos_grappler_impl(), ) # File compiled with extra flags to get cpu-specific acceleration. @@ -2094,7 +2107,7 @@ tf_cuda_library( ":core_cpu_base", ":proto_text", "//tensorflow/core/grappler:grappler_item", - ] + if_static([":core_cpu_impl"]) + tf_protos_all(), + ] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(), ) tf_cuda_library( @@ -2331,7 +2344,7 @@ cc_library( ":lib_internal", ":proto_text", "//third_party/eigen3", - "@local_config_sycl//sycl:sycl", + "@local_config_sycl//sycl", ], alwayslink = 0, ) diff --git a/tensorflow/core/api_def/base_api/api_def_Batch.pbtxt b/tensorflow/core/api_def/base_api/api_def_Batch.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..aea11b64fdc08576e619616856d9f7cf12392eab --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Batch.pbtxt @@ -0,0 +1,42 @@ +op { + graph_op_name: "Batch" + summary: "Batches all input tensors nondeterministically." + description: <& ops, } } // namespace -// Returns ApiDef text representation in multi-line format +// Returns ApiDefs text representation in multi-line format // constructed based on the given op. string CreateApiDef(const OpDef& op) { - ApiDef api_def; - FillBaseApiDef(&api_def, op); + ApiDefs api_defs; + FillBaseApiDef(api_defs.add_op(), op); const std::vector multi_line_fields = {"description"}; - string new_api_defs_str = api_def.DebugString(); + string new_api_defs_str = api_defs.DebugString(); return PBTxtToMultiline(new_api_defs_str, multi_line_fields); } diff --git a/tensorflow/core/api_def/update_api_def.h b/tensorflow/core/api_def/update_api_def.h index 5eae7e528efae43d533d76f2ca96d6a016a63961..1e285c06883efa9e8952339f952e341a5bee7406 100644 --- a/tensorflow/core/api_def/update_api_def.h +++ b/tensorflow/core/api_def/update_api_def.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ +#ifndef TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ +#define TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ // Functions for updating ApiDef when new ops are added. #include "tensorflow/core/framework/op_def.pb.h" @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { -// Returns ApiDef text representation in multi-line format +// Returns ApiDefs text representation in multi-line format // constructed based on the given op. string CreateApiDef(const OpDef& op); @@ -42,4 +42,4 @@ void CreateApiDefs(const OpList& ops, const string& api_def_dir, const string& op_file_pattern); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ +#endif // TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ diff --git a/tensorflow/core/api_def/update_api_def_test.cc b/tensorflow/core/api_def/update_api_def_test.cc index 8948f2c1d5b9f03d418bc11d6481b2b98cb37693..4200c9da23c09335d8edca217f68b2ae5d8c2bdf 100644 --- a/tensorflow/core/api_def/update_api_def_test.cc +++ b/tensorflow/core/api_def/update_api_def_test.cc @@ -173,30 +173,32 @@ description: "Description\nfor Op1." OpDef op; protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT - const string expected_api_def = R"(graph_op_name: "Op1" -in_arg { - name: "a" - description: <debug_options; } - std::shared_ptr ek(new ExecutorsAndKeys); std::unique_ptr func_info(new FunctionInfo); + std::shared_ptr ek(new ExecutorsAndKeys); // The executor_lock_ is intentionally released while executor is // being created. diff --git a/tensorflow/core/common_runtime/function_testlib.h b/tensorflow/core/common_runtime/function_testlib.h index 0bf6699f5aa13b7f125f7f3bb2c1781c90ee9ed9..3ddb26de929dc19792142dffde345672aafaadce 100644 --- a/tensorflow/core/common_runtime/function_testlib.h +++ b/tensorflow/core/common_runtime/function_testlib.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/framework/function.h" @@ -34,4 +34,4 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name, } // namespace test } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 0e5b6b7ef87f67bcb0b46d6e0acec82f8612b80f..933d700f6042bf51f11f773d731cece6ef5af436 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -762,7 +762,7 @@ int64 MinSystemMemory(int64 available_memory) { // is necessary. min_system_memory *= 2; #endif -#if defined(NVIDIA_TEGRA) +#if defined(ANDROID_TEGRA) // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM min_system_memory = 1<<30; #endif diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h index ff81ccd4325e0ad22636cd78ba99e0bff6a03347..4e9c4abce1264d0533c10c1d4dcfcc3f1455e727 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ #include "tensorflow/core/lib/gtl/int_type.h" @@ -85,4 +85,4 @@ TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h index 78e51c84c146693dfc02ce445bda030797de6c07..6d196b16eddfb4b77db97cd098538a7e1f7cae5b 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_ #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" @@ -58,4 +58,4 @@ class GpuIdUtil { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h index 006b2ca44817a37dd7d88018d6f1edef18f07787..2d49a64c0fd93bde2f9ddf4503ea9adc97571b5d 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ #include "tensorflow/core/framework/allocator.h" @@ -33,4 +33,4 @@ class GpuManagedAllocator : public Allocator { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 8f3a0821346f7485bc82e0f7a29076abdce7d4e9..8477cea126f1808d9472bd4f4127fd43e172848e 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -60,4 +60,4 @@ class GraphOptimizer { } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ diff --git a/tensorflow/core/common_runtime/memory_types.h b/tensorflow/core/common_runtime/memory_types.h index fa0a7595f32ac8bb43010dcd3a407825ef79f618..f854acfdc55d66c1ffa93acc0954edae393b2359 100644 --- a/tensorflow/core/common_runtime/memory_types.h +++ b/tensorflow/core/common_runtime/memory_types.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/graph/graph.h" @@ -45,4 +45,4 @@ Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ diff --git a/tensorflow/core/common_runtime/pending_counts.h b/tensorflow/core/common_runtime/pending_counts.h index 5707f5259228c0e54d6d858652a8c50986c0c49b..5e1925c40167fca0abe534e95bed487c77cd2215 100644 --- a/tensorflow/core/common_runtime/pending_counts.h +++ b/tensorflow/core/common_runtime/pending_counts.h @@ -1,5 +1,5 @@ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. @@ -328,4 +328,4 @@ inline PendingCounts::Handle PendingCounts::Layout::CreateHandle( } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 38003b772630221f3681866309a1a83a526eb95c..a1adc4b6b35950339b727774c45014ef71839554 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ #include @@ -173,4 +173,4 @@ class ProcessFunctionLibraryRuntime { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ diff --git a/tensorflow/core/common_runtime/profile_handler.h b/tensorflow/core/common_runtime/profile_handler.h index 57c83c2e6f3c281c83c2596d3ca83dca221d5965..9d31b1aecbce210e8409db66aeb20a8e9245d9bc 100644 --- a/tensorflow/core/common_runtime/profile_handler.h +++ b/tensorflow/core/common_runtime/profile_handler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/graph/types.h" @@ -80,4 +80,4 @@ class ProfileHandler { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index c5c204d4faff8c5016cc0a48fec266b06409b668..fe4df1c106c5a86d4a9cdb73bafed7f4431e76a0 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/util/device_name_utils.h" @@ -134,4 +134,4 @@ class RenamedDevice : public Device { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h index 3b6354603b2925dd7a1d2abe34308e9c8865f6bb..aad910f6d800f0043fba0fbad43801fd3b0ba914 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.h +++ b/tensorflow/core/common_runtime/rendezvous_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ #include @@ -49,4 +49,4 @@ Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index da42c30ce949dbc3a953d20d0ff3333b6ba1b1d5..75eb5bf0d2972e6bccdd9c2c265f3494821210cc 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ #include @@ -303,4 +303,4 @@ class ShapeRefiner { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ diff --git a/tensorflow/core/common_runtime/stats_publisher_interface.h b/tensorflow/core/common_runtime/stats_publisher_interface.h index b285420798761d70822f94afd622afbd1c2b5e0e..f063ee5297deed168abf9807792a0342dcf5f963 100644 --- a/tensorflow/core/common_runtime/stats_publisher_interface.h +++ b/tensorflow/core/common_runtime/stats_publisher_interface.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ #include "tensorflow/core/common_runtime/build_graph_options.h" #include "tensorflow/core/common_runtime/profile_handler.h" @@ -61,4 +61,4 @@ std::unique_ptr CreateNoOpStatsPublisher( } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 2db7ebd7952c9e1edf374267ee33f697eb846885..f4ee841032bf2b78b70fd446a6e4679bd9c943f1 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -556,3 +556,47 @@ tf_cuda_cc_test( "//tensorflow/core/kernels:array", ], ) + +cc_library( + name = "request_id", + srcs = ["request_id.cc"], + hdrs = ["request_id.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "request_id_test", + size = "small", + srcs = ["request_id_test.cc"], + deps = [ + ":request_id", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "recent_request_ids", + srcs = ["recent_request_ids.cc"], + hdrs = ["recent_request_ids.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:worker_proto_cc", + ], +) + +tf_cc_test( + name = "recent_request_ids_test", + size = "small", + srcs = ["recent_request_ids_test.cc"], + deps = [ + ":recent_request_ids", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:worker_proto_cc", + ], +) diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h index 3deb80dff79e7f54684b39d4bd17a63b99836eab..d3ca350e3659ffa9f8248d2be80a1b1f0303addc 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ #include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/distributed_runtime/worker_session.h" @@ -74,4 +74,4 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h index 5fc21d3a1e25faa5f6478914c69a3d513b50530c..c20b40329ab1712b3dd0cae673d337481ee40196 100644 --- a/tensorflow/core/distributed_runtime/local_master.h +++ b/tensorflow/core/distributed_runtime/local_master.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ #include @@ -98,4 +98,4 @@ class LocalMaster : public MasterInterface { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 7113d73dd77c6141c904388b3fb9a28c7561daf2..79fa6f926ea6afb351eacf279d3cf493b6d4713f 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cost_graph.pb.h" @@ -702,4 +702,4 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW +#endif // TENSORFLOW diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.h b/tensorflow/core/distributed_runtime/partial_run_mgr.h index af56e723a9a7e6710b06943c3806ca3690667810..e95f4da6c30b14b9766ef43bf8ef231a1db91ca8 100644 --- a/tensorflow/core/distributed_runtime/partial_run_mgr.h +++ b/tensorflow/core/distributed_runtime/partial_run_mgr.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ #include @@ -84,4 +84,4 @@ class PartialRunMgr { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.cc b/tensorflow/core/distributed_runtime/recent_request_ids.cc new file mode 100644 index 0000000000000000000000000000000000000000..c30879406c6924aa85ad4bf8279b278eaf5d29fd --- /dev/null +++ b/tensorflow/core/distributed_runtime/recent_request_ids.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +RecentRequestIds::RecentRequestIds(int num_tracked_request_ids) + : circular_buffer_(num_tracked_request_ids) { + set_.reserve(num_tracked_request_ids); +} + +Status RecentRequestIds::TrackUnique(int64 request_id, + const string& method_name, + const protobuf::Message& request) { + mutex_lock l(mu_); + if (request_id == 0) { + // For backwards compatibility, allow all requests with request_id 0. + return Status::OK(); + } + if (set_.count(request_id) > 0) { + // Note: RecentRequestIds is not strict LRU because we don't update + // request_id's age in the circular_buffer_ if it's tracked again. Strict + // LRU is not useful here because returning this error will close the + // current Session. + return errors::Aborted("The same ", method_name, + " request was received twice. ", + request.ShortDebugString()); + } + + // Remove the oldest request_id from the set_. circular_buffer_ is + // zero-initialized, and zero is never tracked, so it's safe to do this even + // when the buffer is not yet full. + set_.erase(circular_buffer_[next_index_]); + circular_buffer_[next_index_] = request_id; + set_.insert(request_id); + next_index_ = (next_index_ + 1) % circular_buffer_.size(); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.h b/tensorflow/core/distributed_runtime/recent_request_ids.h new file mode 100644 index 0000000000000000000000000000000000000000..e8e45331dd5a26e2230bb92e8ce73888d3f28505 --- /dev/null +++ b/tensorflow/core/distributed_runtime/recent_request_ids.h @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +// RecentRequestIds tracks recent 64-bit request_ids. When maximum capacity is +// reached, the oldest request_id is evicted. Thread safe. +// +// Some RPCs like RecvTensor are unsafe to retry. For example, RecvTensor pairs +// one sender and one receiver, and the receiver waits for the sender's tensor. +// Retried RecvTensor requests are problematic, because the original RecvTensor +// request may have consumed the sender's tensor, so a retried request might +// block forever. RecentRequestIds identifies retried requests, so we can fail +// them instead of blocking forever. +// +// Internally, recent request_ids are stored in two data structures: a set and a +// circular buffer. The set is used for efficient lookups, and the circular +// buffer tracks the oldest request_id. When the buffer is full, the new +// request_id replaces the oldest request_id in the circular buffer, and the +// oldest request_id is removed from the set. +class RecentRequestIds { + public: + // num_tracked_request_ids should be much larger than the number of RPCs that + // can be received in a small time window. For example, we observed a peak RPC + // rate of ~700 RecvTensor RPC/s when training inception v3 on TPUs, so we + // currently set num_tracked_request_ids to 100,000 for RecvTensor. + RecentRequestIds(int num_tracked_request_ids); + + // Returns OK iff request_id has not been seen in the last + // num_tracked_request_ids insertions. For backwards compatibility, this + // always returns OK for request_id 0. The method_name and the request's + // ShortDebugString are added to returned errors. + Status TrackUnique(int64 request_id, const string& method_name, + const protobuf::Message& request); + + private: + mutex mu_; + // next_index_ indexes into circular_buffer_, and points to the next storage + // space to use. When the buffer is full, next_index_ points at the oldest + // request_id. + int next_index_ GUARDED_BY(mu_) = 0; + std::vector circular_buffer_ GUARDED_BY(mu_); + gtl::FlatSet set_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ diff --git a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a0facf5404bb4e6d0d57f55bcd1f2a4f4f99dba --- /dev/null +++ b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" + +#include + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +Status TrackUnique(int64 request_id, RecentRequestIds* recent_request_ids) { + RecvTensorRequest request; + request.set_request_id(request_id); + return recent_request_ids->TrackUnique(request_id, "recent_request_ids_test", + request); +} + +// request_id 0 is always valid. +TEST(RecentRequestIds, Zero) { + RecentRequestIds recent_request_ids(1); + EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok()); + EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok()); + EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok()); +} + +TEST(RecentRequestIds, Unordered) { + // Capacity for 6 numbers. + RecentRequestIds recent_request_ids(6); + + // Some unordered numbers to insert into request_id_set. + std::vector numbers = {53754, 23351, 164101, 7476, + 162432, 130761, 164102}; + + // Insert numbers[0..6) and check that all previously inserted numbers remain + // in the set. + for (int i = 0; i < 6; ++i) { + TF_EXPECT_OK(TrackUnique(numbers[i], &recent_request_ids)); + + for (int j = 0; j <= i; ++j) { + EXPECT_FALSE(TrackUnique(numbers[j], &recent_request_ids).ok()) + << "i=" << i << " j=" << j; + } + } + + // Insert numbers[6]. Inserting this 7th number should evict the first number + // from the set. The set should only contain numbers[1..7). + TF_EXPECT_OK(TrackUnique(numbers[6], &recent_request_ids)); + for (int i = 1; i < 7; ++i) { + EXPECT_FALSE(TrackUnique(numbers[i], &recent_request_ids).ok()) + << "i=" << i; + } + + // Insert numbers[0] again. This should succeed because we just evicted it + // from the set. + TF_EXPECT_OK(TrackUnique(numbers[0], &recent_request_ids)); +} + +// Check that the oldest request_id is evicted. +void TestOrdered(int num_request_ids) { + RecentRequestIds recent_request_ids(num_request_ids); + + // Insert [1..101). The current number and the (num_request_ids - 1) preceding + // numbers should still be in the set. + for (int i = 1; i < 101; ++i) { + TF_EXPECT_OK(TrackUnique(i, &recent_request_ids)); + + for (int j = std::max(1, i - num_request_ids + 1); j <= i; ++j) { + EXPECT_FALSE(TrackUnique(j, &recent_request_ids).ok()) + << "i=" << i << " j=" << j; + } + } +} + +// Test eviction with various numbers of buckets. +TEST(RecentRequestIds, Ordered2) { TestOrdered(2); } +TEST(RecentRequestIds, Ordered3) { TestOrdered(3); } +TEST(RecentRequestIds, Ordered4) { TestOrdered(4); } +TEST(RecentRequestIds, Ordered5) { TestOrdered(5); } + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/request_id.cc b/tensorflow/core/distributed_runtime/request_id.cc new file mode 100644 index 0000000000000000000000000000000000000000..230c6f9601355d4f6e904f4c3a762cd9d44f72c9 --- /dev/null +++ b/tensorflow/core/distributed_runtime/request_id.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/request_id.h" + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +int64 GetUniqueRequestId() { + int64 request_id = 0; + while (request_id == 0) { + request_id = random::New64(); + } + return request_id; +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/request_id.h b/tensorflow/core/distributed_runtime/request_id.h new file mode 100644 index 0000000000000000000000000000000000000000..a882b69ab16bea32c0f0fae394a8cce5dc469d27 --- /dev/null +++ b/tensorflow/core/distributed_runtime/request_id.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_ + +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Returns a request_id for use with RecentRequestIds. This number will not be +// zero, and must be unique over RecentRequestIds' window of +// num_tracked_request_ids. See recent_request_ids.h for more details. +int64 GetUniqueRequestId(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_ diff --git a/tensorflow/core/distributed_runtime/request_id_test.cc b/tensorflow/core/distributed_runtime/request_id_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0dc9d934723cfa5bea8ad3bf6377ab47bbe40a0 --- /dev/null +++ b/tensorflow/core/distributed_runtime/request_id_test.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/request_id.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// Try requesting some request_ids and verify that none are zero. +TEST(GetUniqueRequestId, Basic) { + for (int i = 0; i < 1000000; ++i) { + EXPECT_NE(GetUniqueRequestId(), 0); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 80640c806deedccbe15bdca3216e0c0d195045e1..dade26abc6a3c58f24c759ad863600a156985708 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -186,6 +186,7 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:graph_mgr", + "//tensorflow/core/distributed_runtime:recent_request_ids", "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", "//tensorflow/core/distributed_runtime:worker", "//tensorflow/core/distributed_runtime:worker_cache", @@ -270,6 +271,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:request_id", "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_env", diff --git a/tensorflow/core/distributed_runtime/rpc/async_service_interface.h b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h index 63b0f2272d6aa711c8ce77f00b1f2619efafccc9..b2730a583b1252d8703495782e30caf8f5fa3a46 100644 --- a/tensorflow/core/distributed_runtime/rpc/async_service_interface.h +++ b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ namespace tensorflow { @@ -38,4 +38,4 @@ class AsyncServiceInterface { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h index 2ab0a40f333bf995a3847ef9bf35d1381512c16c..ecad1274cc14c7f03eddf6fbb806e886b0c7d0b2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_call.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/macros.h" @@ -265,4 +265,4 @@ class Call : public UntypedCall { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h index c662cde9be8998b8303b345403620ca920f3ca92..de9840fca8c312bffefea501522210cafc2af82e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ #include #include @@ -93,4 +93,4 @@ Status NewHostPortGrpcChannel(const string& target, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h index 95c2c935f091abc808a7fb0ee8446ced5e1d184b..d367b83ee7fac5001bd83737531689b64a7e3774 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ #include "grpc++/grpc++.h" @@ -41,4 +41,4 @@ class GrpcClientCQTag { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h index 8770dcc3ac9bf7f0b6c7544a34ccb6d6fa5966b5..473604f257607456d0fb4dcb6d9189f2f6dba135 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ #include #include "tensorflow/core/platform/types.h" @@ -34,4 +34,4 @@ AsyncServiceInterface* NewGrpcMasterService(Master* master, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h index 412395c52635d5c3cda95dddea50f7cd2d8c8e4f..4e203e260a1a370cc2bc7e40c3ce9e84da4d3ad4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ #include "grpc++/impl/codegen/async_stream.h" #include "grpc++/impl/codegen/async_unary_call.h" @@ -186,4 +186,4 @@ class MasterService final { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h index d661caaa6029dc29c9eb8983c009f232fb2b3cbf..c80668e899d100edd65649c5588177655d1d0b7e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ #include "tensorflow/core/distributed_runtime/master_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" @@ -24,4 +24,4 @@ namespace tensorflow { MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h index 8ad41335409e0a7f7576134ed12b1a233aa341e0..709c3833e7aaa8b61656693e376c1d3060e0bb35 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ -#define THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ +#ifndef TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ +#define TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ #include @@ -35,4 +35,4 @@ WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ +#endif // TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h index b35d4843e8482dc15c6013f9cd0486f8feea754a..dd114d39c62f6b69a3fb9ea4401459f963137a1f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ #include "grpc++/impl/codegen/proto_utils.h" #include "grpc++/support/slice.h" @@ -231,4 +231,4 @@ class UnlimitedSizeProtoSerializationTraits { : public UnlimitedSizeProtoSerializationTraits {}; \ } // namespace grpc -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index c3f513d4926e9abe59561e4146237d3ced244ea7..8b12ac1461d6b1fa3098197aa7697031a5d3075b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ #include @@ -141,4 +141,4 @@ class GrpcServer : public ServerInterface { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h index 300f7271249d88e4aa2153e64d2b2671a6168b65..d87956a13515fde533e746d2abd04e4a2f4959ae 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ #include #include @@ -130,4 +130,4 @@ class GrpcSession : public Session { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h index 3f80bdfb70d0f3054b35a17ee34ec53655ccccc1..0b6f9474dd9e520b21c1915578cd8071a28ac7fd 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ #include @@ -96,4 +96,4 @@ class RPCState : public GrpcClientCQTag { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h index 5e81b90189484053907f9b3f70154d1f2ce25775..4b3a03b1d708744bded25ff4d320979bb7eb38b2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ #include #include @@ -70,4 +70,4 @@ class TestCluster { } // end namespace test } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h index bb854783472c4a5e1261e9e737f4b830e5cbf3e2..d5e7e9f5b39e9f1ab9704de3f8ec7964096ae569 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ #include @@ -114,4 +114,4 @@ class GrpcByteBufferSource : public ::grpc::protobuf::io::ZeroCopyInputStream { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h index 17a307a6d99748c4d5daa689c9967622ed933d87..7a35fdbca08e1f7a79e77418f69efb3e4fa80e0a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" @@ -29,4 +29,4 @@ WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker( const string& local_target); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 15faf21dafc2ee1a2a6d6ad6463b87aa9a62d88d..95811476f789be0225231f86aa0242db71b81199 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -354,7 +354,8 @@ class GrpcWorkerService : public AsyncServiceInterface { } // namespace -GrpcWorker::GrpcWorker(WorkerEnv* worker_env) : Worker(worker_env) {} +GrpcWorker::GrpcWorker(WorkerEnv* worker_env) + : Worker(worker_env), recv_tensor_recent_request_ids_(100000) {} // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol // buffers for a response object, to avoid extra protocol buffer serialization @@ -363,11 +364,18 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { + Status s = recv_tensor_recent_request_ids_.TrackUnique( + request->request_id(), "RecvTensor (GrpcWorker)", *request); + if (!s.ok()) { + done(s); + return; + } + const int64 step_id = request->step_id(); const string& key = request->rendezvous_key(); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); Rendezvous::ParsedKey parsed; - Status s = Rendezvous::ParseKey(key, &parsed); + s = Rendezvous::ParseKey(key, &parsed); Device* src_dev = nullptr; if (s.ok()) { s = PrepareRecvTensor(parsed, &src_dev); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index 64d7c986daf1f78dafdbdf459034fd51db4d699d..78a21fd9f6ecb6deac171bb5c4a16fa074988fa2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" #include "tensorflow/core/distributed_runtime/worker.h" namespace grpc { @@ -40,6 +41,9 @@ class GrpcWorker : public Worker { StatusCallback done); WorkerEnv* env(); + + private: + RecentRequestIds recv_tensor_recent_request_ids_; }; std::unique_ptr NewGrpcWorker(WorkerEnv* worker_env); @@ -50,4 +54,4 @@ std::unique_ptr NewGrpcWorkerService( } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index fb23f8631fd17a7533fde01cde9453dc8ea8505a..1a5e2edfb240198c50d3b5d00bec1127fceff725 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ #include "grpc++/impl/codegen/async_stream.h" #include "grpc++/impl/codegen/async_unary_call.h" @@ -147,4 +147,4 @@ class WorkerService final { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 72dfe5c062177de7039980ece31778e7cac06592..067dc5dff5bb81f8cc1da883d226ee3cfa5638f2 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/tensor_coding.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" @@ -67,6 +68,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { done_ = std::move(done); req_.set_step_id(step_id); req_.set_rendezvous_key(key.data(), key.size()); + req_.set_request_id(GetUniqueRequestId()); } void Reset(WorkerCacheInterface* wc) { diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h index a064d20cdb84fe82a53e85a95944301e9761bb03..275f526d311aec571a2f2ffb8a377d952b6ae8dc 100644 --- a/tensorflow/core/distributed_runtime/server_lib.h +++ b/tensorflow/core/distributed_runtime/server_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ #include @@ -95,4 +95,4 @@ Status NewServer(const ServerDef& server_def, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index ba077c3accff672f088bb7222858197b43ea4676..3ce260d12e92e3458fe12f3f5b5723f9c39b5f4b 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ #include @@ -87,4 +87,4 @@ class SessionMgr { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index c62347926fa11c135b6116d17f6545007e9f6115..62fa5f3cf54202c91b27ae03d9d34fc09b8392ec 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ #include @@ -120,4 +120,4 @@ class Worker : public WorkerInterface { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index 9da3bb253f838efdf6d4dd97575f7ae48ba95ab1..0fd19ac27f20edbf8a2ed85d1c3d97eaabab3347 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ #include @@ -61,4 +61,4 @@ struct WorkerSession { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ diff --git a/tensorflow/core/example/example_parser_configuration.h b/tensorflow/core/example/example_parser_configuration.h index 69955ec4cb3deb92587e4ed95382e5eaf9f74eab..3d06bd55e2bdd845c598078438dac79edf7e475e 100644 --- a/tensorflow/core/example/example_parser_configuration.h +++ b/tensorflow/core/example/example_parser_configuration.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_ +#ifndef TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_ +#define TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_ #include #include @@ -53,4 +53,4 @@ Status ExampleParserConfigurationProtoToFeatureVectors( } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSE_CONFIGURATION_H_ +#endif // TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSE_CONFIGURATION_H_ diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index c0deb473a25cf19b99ae79903c1a2014b6e378f7..293c40e04d6ad9b57aabfda678216b1805a006f4 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ +#ifndef TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ +#define TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ #include @@ -287,4 +287,4 @@ Status ExplicitShape(InferenceContext* c); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ +#endif // TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h new file mode 100644 index 0000000000000000000000000000000000000000..2c2c7e7c585c9364e1d08280d5fe76f1bf1eff23 --- /dev/null +++ b/tensorflow/core/framework/dataset.h @@ -0,0 +1,75 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_DATASET_H_ +#define TENSORFLOW_FRAMEWORK_DATASET_H_ + +namespace tensorflow { +namespace dataset { +// Registry for stateful ops that need to be used in dataset functions. +// See below macro for usage details. +class WhitelistedStatefulOpRegistry { + public: + Status Add(StringPiece op_name) { + op_names_.insert(op_name); + return Status::OK(); + } + + bool Contains(StringPiece op_name) { + return op_names_.find(op_name) != op_names_.end(); + } + + static WhitelistedStatefulOpRegistry* Global() { + static WhitelistedStatefulOpRegistry* reg = + new WhitelistedStatefulOpRegistry; + return reg; + } + + private: + WhitelistedStatefulOpRegistry() {} + WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); + WhitelistedStatefulOpRegistry operator=( + WhitelistedStatefulOpRegistry const& copy); + std::set op_names_; +}; + +} // namespace dataset + +// Use this macro to whitelist an op that is marked stateful but needs to be +// used inside a map_fn in an input pipeline. This is only needed if you wish +// to be able to checkpoint the state of the input pipeline. We currently +// do not allow stateful ops to be defined inside of map_fns since it is not +// possible to save their state. +// Note that the state of the whitelisted ops inside functions will not be +// saved during checkpointing, hence this should only be used if the op is +// marked stateful for reasons like to avoid constant folding during graph +// optimiztion but is not stateful. +// If possible, try to remove the stateful flag on the op first. +// Example usage: +// +// WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader"); +// +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \ + WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ + WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ + static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ + name) + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_DATASET_H_ diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c879dc6f3f6039fad268680f52de128a4ae8a8f6..aee3a0afbca23a180d5415fef2b1b405f23b3f53 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -1162,24 +1163,51 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const { } #endif -void OpKernelConstruction::CtxFailure(Status s) { +void OpKernelConstruction::CtxFailure(const Status& s) { VLOG(1) << s; SetStatus(s); } -void OpKernelConstruction::CtxFailureWithWarning(Status s) { +void OpKernelConstruction::CtxFailureWithWarning(const Status& s) { LOG(WARNING) << s; SetStatus(s); } -void OpKernelContext::CtxFailure(Status s) { +void OpKernelConstruction::CtxFailure(const char* file, int line, + const Status& s) { + VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line + << " : " << s; + SetStatus(s); +} + +void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line, + const Status& s) { + LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line + << " : " << s; + SetStatus(s); +} + +void OpKernelContext::CtxFailure(const Status& s) { VLOG(1) << s; SetStatus(s); } -void OpKernelContext::CtxFailureWithWarning(Status s) { +void OpKernelContext::CtxFailureWithWarning(const Status& s) { LOG(WARNING) << s; SetStatus(s); } +void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) { + VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line + << " : " << s; + SetStatus(s); +} + +void OpKernelContext::CtxFailureWithWarning(const char* file, int line, + const Status& s) { + LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line + << " : " << s; + SetStatus(s); +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 25150499ad76c45493645a9ee4a83fd55e69eb13..b72f1405cffd83439dd837fa7f8e641ecf44e2ae 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -316,8 +316,10 @@ class OpKernelConstruction { int graph_def_version() const { return graph_def_version_; } // Helper routines for the OP_REQUIRES macros - void CtxFailure(Status s); - void CtxFailureWithWarning(Status s); + void CtxFailure(const Status& s); + void CtxFailureWithWarning(const Status& s); + void CtxFailure(const char* file, int line, const Status& s); + void CtxFailureWithWarning(const char* file, int line, const Status& s); // Unrecommended functions: these are functions that have some // current uses but are not recommended for use, and may go away at @@ -1014,8 +1016,10 @@ class OpKernelContext { } // Helper routines for the OP_REQUIRES macros - void CtxFailure(Status s); - void CtxFailureWithWarning(Status s); + void CtxFailure(const Status& s); + void CtxFailureWithWarning(const Status& s); + void CtxFailure(const char* file, int line, const Status& s); + void CtxFailureWithWarning(const char* file, int line, const Status& s); // Unrecommended functions: these are functions that have some // current uses but are not recommended for use, and may go away at @@ -1476,40 +1480,40 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { // ... // } -#define OP_REQUIRES(CTX, EXP, STATUS) \ - do { \ - if (!TF_PREDICT_TRUE(EXP)) { \ - (CTX)->CtxFailure((STATUS)); \ - return; \ - } \ +#define OP_REQUIRES(CTX, EXP, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) { \ + (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ + return; \ + } \ } while (0) -#define OP_REQUIRES_OK(CTX, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(_s); \ - return; \ - } \ +#define OP_REQUIRES_OK(CTX, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return; \ + } \ } while (0) -#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ - do { \ - if (!TF_PREDICT_TRUE(EXP)) { \ - (CTX)->CtxFailure((STATUS)); \ - (CALLBACK)(); \ - return; \ - } \ +#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) { \ + (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ + (CALLBACK)(); \ + return; \ + } \ } while (0) -#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ - do { \ - ::tensorflow::Status _s(STATUS); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(_s); \ - (CALLBACK)(); \ - return; \ - } \ +#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + (CALLBACK)(); \ + return; \ + } \ } while (0) } // namespace tensorflow diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index edc93aec7f801b77a5c7867589f9d89ff7b6ea8f..e062adffe821464cd349227cde17b9d4db54c44e 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -53,7 +53,7 @@ limitations under the License. */ #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \ - defined(NVIDIA_TEGRA) + defined(ANDROID_TEGRA) // All types are supported, so all macros are invoked. // diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 4a4ef12635f867fccb594d50a2c9e8f3059ce337..d552ec1693f89a6695609681f2e8bffa9d78f93c 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ #include @@ -787,4 +787,4 @@ Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { } // namespace shape_inference } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index fbfd24538bc7a5b1f3ee3805d4a803a0e7239fca..7977841482efa396c8e0797d8c80a40c11b4df56 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ #include #include "tensorflow/core/framework/node_def.pb.h" @@ -98,4 +98,4 @@ class ShapeInferenceTestutil { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 4f08cdc1d7c130bd351de7b5f7574ea199977804..77a3edcc10e9c5ceb8bf26570c3e271f9e853444 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -615,11 +615,11 @@ void Tensor::CheckType(DataType expected_dtype) const { void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const { CHECK_EQ(dtype(), expected_dtype); - CHECK(IsAligned()); + CHECK(IsAligned()) << "CheckTypeAndIsAligned"; } void Tensor::CheckIsAlignedAndSingleElement() const { - CHECK(IsAligned()); + CHECK(IsAligned()) << "Aligned and single element"; CHECK_EQ(1, NumElements()) << "Must have a one element tensor"; } diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 92d10f0d8cf452264885917bc0c897e03527a782..94c39c53a6fbb6a30e054346a2ec608a6970c373 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -660,7 +660,8 @@ void Tensor::FillDimsAndValidateCompatibleShape( template typename TTypes::Tensor Tensor::shaped( gtl::ArraySlice new_sizes) { - CheckTypeAndIsAligned(DataTypeToEnum::v()); + CheckType(DataTypeToEnum::v()); + CHECK(IsAligned()); Eigen::array dims; FillDimsAndValidateCompatibleShape(new_sizes, &dims); return typename TTypes::Tensor(base(), dims); @@ -687,7 +688,8 @@ typename TTypes::UnalignedTensor Tensor::unaligned_shaped( template typename TTypes::ConstTensor Tensor::shaped( gtl::ArraySlice new_sizes) const { - CheckTypeAndIsAligned(DataTypeToEnum::v()); + CheckType(DataTypeToEnum::v()); + CHECK(IsAligned()); Eigen::array dims; FillDimsAndValidateCompatibleShape(new_sizes, &dims); return typename TTypes::ConstTensor(base(), dims); diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index cb8e77f1df962eb36277ac7c01e8b580d5926452..ded6aa09918f873b975f537fa33dcd55902090fe 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -453,6 +453,13 @@ inline bool DataTypeIsInteger(DataType dt) { return kDataTypeIsInteger.Contains(dt); } +// Is the dtype a signed integral type? +constexpr DataTypeSet kDataTypeIsSigned = + ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64); +inline bool DataTypeIsSigned(DataType dt) { + return kDataTypeIsSigned.Contains(dt); +} + // Is the dtype an unsigned integral type? constexpr DataTypeSet kDataTypeIsUnsigned = ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64); diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index 395329da3bee01cf73c69d52b150b88f34d1b1ff..ee07db1aee15e578c4bcbac22cecf6e75e95b6e2 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -182,7 +182,7 @@ Status VariantDeviceCopy( // Special casing UnaryOpFn per op and per device. UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( VariantUnaryOp op, StringPiece device, StringPiece type_name) { - auto found = unary_op_fns.find(std::make_tuple(op, device, type_name)); + auto found = unary_op_fns.find({op, device, type_name}); if (found == unary_op_fns.end()) return nullptr; return &found->second; } @@ -195,12 +195,10 @@ void UnaryVariantOpRegistry::RegisterUnaryOpFn( CHECK_EQ(existing, nullptr) << "Unary VariantUnaryOpFn for type_name: " << type_name << " already registered for device type: " << device; - unary_op_fns.insert( - std::pair, - VariantUnaryOpFn>( - std::make_tuple(op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)), - unary_op_fn)); + unary_op_fns.insert(std::pair, VariantUnaryOpFn>( + {op, GetPersistentStringPiece(device), + GetPersistentStringPiece(type_name)}, + unary_op_fn)); } namespace { @@ -229,7 +227,7 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); UnaryVariantOpRegistry::VariantBinaryOpFn* UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, StringPiece type_name) { - auto found = binary_op_fns.find(std::make_tuple(op, device, type_name)); + auto found = binary_op_fns.find({op, device, type_name}); if (found == binary_op_fns.end()) return nullptr; return &found->second; } @@ -242,12 +240,10 @@ void UnaryVariantOpRegistry::RegisterBinaryOpFn( CHECK_EQ(existing, nullptr) << "Unary VariantBinaryOpFn for type_name: " << type_name << " already registered for device type: " << device; - binary_op_fns.insert( - std::pair, - VariantBinaryOpFn>( - std::make_tuple(op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)), - add_fn)); + binary_op_fns.insert(std::pair, VariantBinaryOpFn>( + {op, GetPersistentStringPiece(device), + GetPersistentStringPiece(type_name)}, + add_fn)); } namespace { diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 13f6908cae1ed1b1964bf827dce0fcb2bee4e6d1..0e2a410429d199998722e68280a8438465988ddd 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -166,6 +166,21 @@ class UnaryVariantOpRegistry { device_copy_fns; // Map std::tuple to function. + + // this breaks by falling victim to "too perfect forwarding" + // see https://stackoverflow.com/questions/44475317/variadic-template-issue + // and references therein + template + struct FuncTuple { + FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname) + : op_type_(op), device_(dev), typename_(tname){}; + Op op_type_; + StringPiece device_, typename_; + }; + //friend declaration for operator== + // needed for clang + template + friend bool operator==(const FuncTuple &l, const FuncTuple &r); struct TupleHash { template std::size_t operator()( @@ -176,18 +191,24 @@ class UnaryVariantOpRegistry { ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x))); return ret; } + + template + std::size_t operator()(const FuncTuple& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast(x.op_type_); + ret = Hash64Combine(ret, sp_hasher_(x.device_)); + ret = Hash64Combine(ret, sp_hasher_(x.typename_)); + return ret; + } StringPieceHasher sp_hasher_; }; - std::unordered_map, - VariantUnaryOpFn, TupleHash> + std::unordered_map, VariantUnaryOpFn, TupleHash> unary_op_fns; - std::unordered_map, - VariantBinaryOpFn, TupleHash> + std::unordered_map, VariantBinaryOpFn, TupleHash> binary_op_fns; // Find or insert a string into a persistent string storage - // container; return the StringPiece pointing to the permanent - // string location. + // container; return the StringPiece pointing to the permanent string location. static StringPiece GetPersistentStringPiece(const string& str) { const auto string_storage = PersistentStringStorage(); auto found = string_storage->find(str); @@ -199,7 +220,12 @@ class UnaryVariantOpRegistry { } } }; - +template +inline bool operator==(const UnaryVariantOpRegistry::FuncTuple& lhs, + const UnaryVariantOpRegistry::FuncTuple& rhs) { + return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) && + (lhs.typename_ == rhs.typename_); +} // Gets a TensorShape from a Tensor containing a scalar Variant. // Returns an Internal error if the Variant does not have a registered shape // function, or if it's a serialized Variant that cannot be decoded. @@ -283,8 +309,8 @@ Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, return errors::Internal( "No unary variant binary_op function found for binary variant op " "enum: ", - op, " Variant type_name: '", a.TypeName(), - "' for device type: ", device); + op, " Variant type_name: '", a.TypeName(), "' for device type: ", + device); } return (*binary_op_fn)(ctx, a, b, out); } diff --git a/tensorflow/core/graph/gradients.h b/tensorflow/core/graph/gradients.h index 75906e6ce96de3deb5bb603fb4ca06763496bb6d..ddfed084b09c1072aae7ae7838d84c4659188bf4 100644 --- a/tensorflow/core/graph/gradients.h +++ b/tensorflow/core/graph/gradients.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_ +#ifndef TENSORFLOW_CORE_GRAPH_GRADIENTS_H_ +#define TENSORFLOW_CORE_GRAPH_GRADIENTS_H_ #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -55,4 +55,4 @@ Status AddSymbolicGradients(gtl::ArraySlice y_node_outputs, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_ +#endif // TENSORFLOW_CORE_GRAPH_GRADIENTS_H_ diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 7abc155c19db06db81a62672f3f9f333272d5a3f..0fe01e9c9e094ebfa7fd1e6200d775ef61775184 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -1,6 +1,10 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cuda_library", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_protos_grappler", +) filegroup( name = "all_files", @@ -37,6 +41,7 @@ tf_proto_library( name = "op_performance_data", srcs = ["op_performance_data.proto"], cc_api_version = 2, + default_header = True, protodeps = tf_additional_all_protos(), visibility = ["//visibility:public"], ) @@ -47,7 +52,6 @@ cc_library( hdrs = ["graph_properties.h"], visibility = ["//visibility:public"], deps = [ - ":op_performance_data_cc", ":utils", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", @@ -55,7 +59,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", - ], + ] + tf_protos_grappler(), ) tf_cc_test( @@ -135,7 +139,7 @@ tf_cuda_library( hdrs = ["utils.h"], visibility = ["//visibility:public"], deps = [ - ":op_performance_data_cc", + "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -143,8 +147,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:utils", - "//third_party/eigen3", - ], + ] + tf_protos_grappler(), ) tf_cc_test( @@ -207,9 +210,8 @@ cc_library( hdrs = ["op_context.h"], visibility = ["//visibility:public"], deps = [ - ":op_performance_data_cc", "//tensorflow/core:protos_all_cc", - ], + ] + tf_protos_grappler(), ) cc_library( @@ -276,12 +278,11 @@ cc_library( deps = [ ":cost_estimator", ":op_context", - ":op_performance_data_cc", + "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/clusters:utils", - "//third_party/eigen3", - ], + ] + tf_protos_grappler(), ) tf_cc_test( @@ -305,7 +306,6 @@ cc_library( ":cost_estimator", ":graph_properties", ":op_level_cost_estimator", - ":op_performance_data_cc", ":utils", ":virtual_placer", ":virtual_scheduler", @@ -314,7 +314,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", - ], + ] + tf_protos_grappler(), ) tf_cc_test( diff --git a/tensorflow/core/grappler/costs/op_context.h b/tensorflow/core/grappler/costs/op_context.h index 735a1e68ea6e30adff297d29f6a9c86111ef7507..6391de4a91ead5032013b3c9143ebcfc9f929901 100644 --- a/tensorflow/core/grappler/costs/op_context.h +++ b/tensorflow/core/grappler/costs/op_context.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" @@ -36,4 +36,4 @@ struct OpContext { } // end namespace grappler } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index c1802509089645a72c5cf06d9b5375553d053841..8ccc51f5451bb2b5052fd04100ba7684b0956cea 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ #include #include @@ -342,4 +342,4 @@ class VirtualScheduler { } // namespace grappler } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 990a07c86c2e144d24505dd45a092884f4ef77bc..9c544c82bf7f77760e5a2090ca947fd7185e27b7 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -436,23 +436,31 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { return true; } +NodeDef* ArithmeticOptimizer::AddNode(const NodeDef& node, StringPiece suffix, + bool copy_node) { + return AddNode(OptimizedNodeName(node, suffix), copy_node ? &node : nullptr); +} + NodeDef* ArithmeticOptimizer::AddNode(const string& name, const NodeDef* node_to_copy) { NodeDef* new_node = optimized_graph_->add_node(); - const string name_with_prefix = - AddPrefixToNodeName(name, kArithmeticOptimizer); - node_map_->AddNode(NodeName(name_with_prefix), new_node); + node_map_->AddNode(NodeName(name), new_node); if (node_to_copy != nullptr) { *new_node = *node_to_copy; } - new_node->set_name(name_with_prefix); + new_node->set_name(name); return new_node; } -bool ArithmeticOptimizer::OptimizedNodeExists(const string& name) { - const string name_with_prefix = - AddPrefixToNodeName(name, kArithmeticOptimizer); - return node_map_->NodeExists(name_with_prefix); +string ArithmeticOptimizer::OptimizedNodeName(const NodeDef& node, + StringPiece suffix) const { + return AddPrefixToNodeName(strings::StrCat(node.name(), "_", suffix), + kArithmeticOptimizer); +} + +bool ArithmeticOptimizer::OptimizedNodeExists(const NodeDef& node, + StringPiece suffix) const { + return node_map_->NodeExists(OptimizedNodeName(node, suffix)); } bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const { @@ -668,17 +676,19 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const DataType src_type = GetSourceDataType(*cast); const DataType dst_type = GetDestinationDataType(*cast); if (IsNumberType(src_type) && IsNumberType(dst_type) && - DataTypeSize(src_type) < DataTypeSize(dst_type)) { - NodeDef* new_transpose = - AddNode(StrCat(transpose->name(), "_", DataTypeString(src_type)), - transpose); + DataTypeSize(src_type) < DataTypeSize(dst_type) && + !OptimizedNodeExists(*cast, DataTypeString(dst_type)) && + !OptimizedNodeExists(*transpose, DataTypeString(src_type))) { + NodeDef* new_transpose = AddNode(*transpose, DataTypeString(src_type), + /*copy_node=*/true); (*new_transpose->mutable_attr())["T"].set_type(src_type); new_transpose->set_input(0, cast->input(0)); node_map_->AddOutput(input->name(), new_transpose->name()); node_map_->AddOutput(NodeName(new_transpose->input(1)), new_transpose->name()); - NodeDef* new_cast = AddNode(StrCat(cast->name(), "_new"), cast); + NodeDef* new_cast = + AddNode(*cast, DataTypeString(dst_type), /*copy_node=*/true); new_cast->set_input(0, new_transpose->name()); node_map_->AddOutput(new_transpose->name(), new_cast->name()); @@ -754,7 +764,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // multiply can be constant-folded. TODO(jingyue): When the weights aren't // constant, this should also help performance a bit and memory usage a lot, // since the weights tend to be smaller than the activations. - if (weights->op() == "Const") { + if (weights->op() == "Const" && + !OptimizedNodeExists(*weights, StrCat("scaled_", conv->name()))) { const NodeDef* source = node_map_->GetNode( GetTailOfValuePreservingChain(*node, *node_map_, nodes_to_preserve_) ->input(0)); @@ -773,7 +784,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( scale_tensor.tensor_shape().dim_size() == 0) { // Create new node `scaled_weights`. NodeDef* scaled_weights = AddNode( - StrCat(weights->name(), "_scaled_", conv->name()), nullptr); + *weights, StrCat("scaled_", conv->name()), /*copy_node=*/false); scaled_weights->set_op("Mul"); scaled_weights->set_device(weights->device()); (*scaled_weights->mutable_attr())["T"] = @@ -810,9 +821,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } if (node->op() == "Mul" && node->input(0) == node->input(1) && - !OptimizedNodeExists(StrCat(node->name(), "_square"))) { - NodeDef* new_square_node = - AddNode(strings::StrCat(node->name(), "_square"), node); + !OptimizedNodeExists(*node, "square")) { + NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true); new_square_node->set_op("Square"); for (int i = 1; i < new_square_node->input_size(); ++i) { new_square_node->set_input(i - 1, new_square_node->input(i)); @@ -847,8 +857,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( break; } } - const string mul_node_name = StrCat(node->name(), "_mul"); - if (all_equal && !OptimizedNodeExists(mul_node_name)) { + if (all_equal && !OptimizedNodeExists(*node, "const") && + !OptimizedNodeExists(*node, "mul")) { // 1. Create constant node with value N. const auto type = GetDataTypeFromAttr(*node, "T"); Tensor t(type, TensorShape({})); @@ -859,15 +869,14 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( return ""; } TensorValue value(&t); - NodeDef* new_const_node = - AddNode(StrCat(node->name(), "_const"), nullptr); + NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false); *new_const_node = ConstantFolding::CreateNodeDef(new_const_node->name(), value); new_const_node->set_device(node->device()); nodes_to_simplify->PushBack(new_const_node); // 2. Replace the aggregate node with Mul(Const(N), x). - NodeDef* new_mul_node = AddNode(mul_node_name, nullptr); + NodeDef* new_mul_node = AddNode(*node, "mul", /*copy_node=*/false); new_mul_node->set_op("Mul"); new_mul_node->set_device(node->device()); SetDataTypeToAttr(type, "T", new_mul_node); @@ -892,7 +901,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // to the following: // Mul(x, AddN(y1, y2, y3, ... yn)) if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 && - !OptimizedNodeExists(StrCat(node->name(), "_hoist_add"))) { + !OptimizedNodeExists(*node, "hoist_add") && + !OptimizedNodeExists(*node, "hoist_mul")) { // Determine the set of common factors if the input nodes are all Mul nodes. std::set common_factors; for (int i = 0; i < node->input_size(); ++i) { @@ -946,10 +956,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( if (shapes_match) { // 1. Use a copy of the first Mul node for the outer multiplication. - NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"), + NodeDef* new_mul_node = AddNode(OptimizedNodeName(*node, "hoist_mul"), node_map_->GetNode(node->input(0))); - NodeDef* new_add_node = - AddNode(StrCat(node->name(), "_hoist_add"), node); + NodeDef* new_add_node = AddNode(*node, "hoist_add", /*copy_node=*/true); new_mul_node->set_device(node->device()); new_mul_node->set_input(0, common_factor); node_map_->AddOutput(common_factor, new_mul_node->name()); @@ -978,7 +987,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // Fold Transpose into matrix multiplication. if ((node->op() == "MatMul" || node->op() == "SparseMatMul" || node->op() == "BatchMatMul") && - !OptimizedNodeExists(StrCat(node->name(), "_fused"))) { + !OptimizedNodeExists(*node, "fused")) { const NodeDef* a = node_map_->GetNode(node->input(0)); const NodeDef* b = node_map_->GetNode(node->input(1)); bool is_complex = false; @@ -996,7 +1005,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 && IsInnerMatrixTransposeNode(*b, node_map_.get()); if (a_is_foldable || b_is_foldable) { - NodeDef* new_op = AddNode(StrCat(node->name(), "_fused"), node); + NodeDef* new_op = AddNode(*node, "fused", /*copy_node=*/true); if (a_is_foldable) { const string attr_a = node->op() == "BatchMatMul" ? "adj_x" : "transpose_a"; @@ -1021,7 +1030,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // Fold Conj into Transpose or ConjugateTranspose. if ((node->op() == "Conj" || node->op() == "Transpose" || node->op() == "ConjugateTranspose") && - !OptimizedNodeExists(StrCat(node->name(), "_fused"))) { + !OptimizedNodeExists(*node, "fused")) { const NodeDef* input = node_map_->GetNode(node->input(0)); const NodeDef* transpose_op = node->op() == "Conj" ? input : node; const NodeDef* conj_op = node->op() == "Conj" ? node : input; @@ -1029,7 +1038,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( if ((transpose_op->op() == "Transpose" || transpose_op->op() == "ConjugateTranspose") && conj_op->op() == "Conj") { - NodeDef* new_op = AddNode(StrCat(node->name(), "_fused"), transpose_op); + NodeDef* new_op = + AddNode(OptimizedNodeName(*node, "fused"), transpose_op); // Flip the type of transpose op to absorb the conjugation. new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose" : "Transpose"); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index ec269792386189e5a590a99af020803810f36b1a..afd538db408aa859a108e08b2de9efad635d515c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -48,7 +48,13 @@ class ArithmeticOptimizer : public GraphOptimizer { private: // Returns true is a node with given name and the optimizer prefix already // exists. - bool OptimizedNodeExists(const string& name); + string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const; + bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const; + + // Creates a new node in the graph, with name equal to that of node, prefixed + // with "ArithmeticOptimizer/" and the given suffix. Also updates node_map_, + // and optionally copies node into the new node if copy_node is true. + NodeDef* AddNode(const NodeDef& node, StringPiece suffix, bool copy_node); // Creates a new node in the graph, prefixed with "ArithmeticOptimizer/", // updates node_map_, and optionally copies *node_to_copy into the new diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index b5b1ec7021e5b901195bc1e6b6b2247410d5ff1b..2a82b250586783759608db75bf9e383f4b0322cb 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -627,7 +627,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(0, std::count_if( @@ -651,7 +651,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(1, std::count_if( @@ -673,7 +673,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(1, std::count_if( @@ -706,7 +706,7 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(1, std::count_if( @@ -730,7 +730,7 @@ TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); const NodeDef* transpose_node = nullptr; @@ -766,7 +766,7 @@ TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); int num_transposes = 0; @@ -800,7 +800,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); std::set nodes_after_optimization; @@ -833,7 +833,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); for (const NodeDef& node : output.node()) { @@ -860,7 +860,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); NodeMap node_map(&output); @@ -889,7 +889,7 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(6, output.node_size()); @@ -920,7 +920,7 @@ TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); NodeMap node_map(&output); @@ -962,7 +962,7 @@ TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); NodeMap node_map(&output); @@ -992,7 +992,7 @@ TEST_F(ArithmeticOptimizerTest, FoldMulToConv) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); NodeMap node_map(&output); @@ -1031,11 +1031,15 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + + item.graph.Swap(&output); TF_EXPECT_OK( ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); NodeMap node_map(&output); @@ -1043,7 +1047,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { const NodeDef* transpose_node = CHECK_NOTNULL(node_map.GetNode(OptimizedName("Transpose_uint8"))); const NodeDef* cast_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_new"))); + CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_float"))); const NodeDef* weights_node = CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D"))); const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); @@ -1080,11 +1084,11 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK( ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); NodeMap node_map(&output); @@ -1113,7 +1117,7 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(1, std::count_if( @@ -1133,7 +1137,7 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(0, std::count_if( @@ -1152,7 +1156,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph = output; + item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(0, std::count_if( diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 6860447fb895c3dc3e0c0a087b1ec0d36898ab28..0aeff6222c291455c04cf3fb68a90298724385dd 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -128,6 +128,42 @@ bool AllValuesAre(const TensorProto& tensor, const T& value) { return false; } +// Add new_input as a control input to node if it does not already depend on it. +// TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and +// clean up code that should be using them. +bool MaybeAddControlInput(const string& new_input, NodeDef* node, + GraphDef* graph, NodeMap* node_map) { + bool already_exists = false; + for (const string& input : node->input()) { + if (input == new_input || AsControlDependency(input) == new_input) { + already_exists = true; + break; + } + } + if (!already_exists) { + const string ctrl_dep = + ConstantFolding::AddControlDependency(new_input, graph, node_map); + node->add_input(ctrl_dep); + node_map->AddOutput(NodeName(new_input), node->name()); + } + return !already_exists; +} + +// Remove old_input as a control input to node. +bool MaybeRemoveControlInput(const string& old_input, NodeDef* node, + GraphDef* graph, NodeMap* node_map) { + for (int i = 0; i < node->input_size(); ++i) { + const string& input = node->input(i); + if (IsControlInput(input) && AsControlDependency(old_input) == input) { + node->mutable_input()->SwapElements(i, node->input_size() - 1); + node->mutable_input()->RemoveLast(); + node_map->RemoveOutput(NodeName(old_input), node->name()); + return true; + } + } + return false; +} + } // namespace ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level, @@ -1524,14 +1560,15 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, // // + + = parent // / \ / \ - // Const + -- > X + = children + // C + -- > X + = children // / \ / \ - // X Y Const Y = leaves + // X Y C Y = leaves // - // where '+' denotes an associative and commutative operator like addition - // or multiplication. This optimization pushes constants down in the tree - // to canonicalize it. Moreoever, in cases where the child node has a - // constant input we will create a node that can be folded, e.g. + // where C is constant and X is non-constant, and '+' denotes an + // associative and commutative operator like addition or multiplication. + // This optimization pushes constants down in the tree to canonicalize it. + // Moreoever, in cases where the child node has a second constant input Y + // we will create a leaf node that can be folded, e.g. // // Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2) // @@ -1540,7 +1577,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, // division/multiplication. // Don't touch BiasAdd since they can't handle vectors as their first // inputs. - if ((IsAdd(*node) || is_mul) && NumNonControlInputs(*node) == 2) { + if (has_fetch_ && (IsAdd(*node) || is_mul) && + NumNonControlInputs(*node) == 2) { NodeDef* left_child = node_map_->GetNode(node->input(0)); NodeDef* right_child = node_map_->GetNode(node->input(1)); // One child must be constant, and the other the same op as the parent. @@ -1556,18 +1594,21 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, node->device() != right_child->device()) { continue; } - NodeDef* child_node = left_child_is_constant ? right_child : left_child; + NodeDef* op_child_node = + left_child_is_constant ? right_child : left_child; + NodeDef* const_child_node = + left_child_is_constant ? left_child : right_child; // Make sure that it is safe to change the value of the child node-> - if (child_node->input_size() < 2 || - NumNonControlOutputs(*child_node, *node_map_) > 1 || !has_fetch_ || - nodes_to_preserve_.find(child_node->name()) != + if (op_child_node->input_size() < 2 || + NumNonControlOutputs(*op_child_node, *node_map_) > 1 || + nodes_to_preserve_.find(op_child_node->name()) != nodes_to_preserve_.end()) { continue; } // Identify the nodes to swap. - const NodeDef* left_leaf = node_map_->GetNode(child_node->input(0)); - const NodeDef* right_leaf = node_map_->GetNode(child_node->input(1)); + NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0)); + NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1)); const bool left_leaf_is_constant = IsReallyConstant(*left_leaf); const bool right_leaf_is_constant = IsReallyConstant(*right_leaf); if (left_leaf_is_constant && right_leaf_is_constant) { @@ -1576,15 +1617,27 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, } const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0; const int parent_const_input = left_child_is_constant ? 0 : 1; + const auto& child_output = node_map_->GetOutputs(op_child_node->name()); + if (child_output.find(const_child_node) != child_output.end()) { + // If there is a control edge from the child op to C, the transformation + // would create a cycle in the graph. We know that it must be a control + // edge. We can replace such a control edge with a control edge from A + // to C. + CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node, + graph_, node_map_.get())); + NodeDef* other_leaf = left_leaf_is_constant ? left_leaf : right_leaf; + MaybeAddControlInput(other_leaf->name(), const_child_node, graph_, + node_map_.get()); + } // Swap the constant child with a non-constant leaf node. node_map_->UpdateInput(node->name(), node->input(parent_const_input), - child_node->input(non_const_leaf_input)); - node_map_->UpdateInput(child_node->name(), - child_node->input(non_const_leaf_input), + op_child_node->input(non_const_leaf_input)); + node_map_->UpdateInput(op_child_node->name(), + op_child_node->input(non_const_leaf_input), node->input(parent_const_input)); std::swap(*node->mutable_input(parent_const_input), - *child_node->mutable_input(non_const_leaf_input)); + *op_child_node->mutable_input(non_const_leaf_input)); graph_modified_ = true; } } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 6aadd9750893bd008b353e6227d82723166edd6e..18acc91e8a18f4bf2eb77c7e5171eaca4ff5bec5 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -52,7 +52,6 @@ class ConstantFolding : public GraphOptimizer { private: string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const; - string OptimizedNodeName(const NodeDef& node) const; bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const; bool IsReallyConstant(const NodeDef& node) const; diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 2db3dc699341ab3b582e6ee17b2611410cf27366..849a88770ae6127c6f2e3fac968a976c0a523a0b 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -80,18 +80,25 @@ TEST_F(ConstantFoldingTest, SimpleFolding) { TEST_F(ConstantFoldingTest, AddTree) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c1 = ops::Const(s.WithOpName("c1"), 2.0f, {1}); Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2}); - Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2}); + Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2}); Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, ops::Placeholder::Shape(TensorShape({2, 2}))); Output add_child = ops::Add(s.WithOpName("add_child"), c2, x); + Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child), + 1.0f, {1}); Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child); - Output mul_child = ops::Mul(s.WithOpName("mul_child"), c2, x); - Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c1, mul_child); - Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c2, x); + + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2}); + Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2}); + Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2}); + Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y); + Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child); + Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x); Output addmul_parent = - ops::Mul(s.WithOpName("addmul_parent"), c1, addmul_child); + ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child); GrapplerItem item; item.fetch = {"add_parent", "mul_parent", "addmul_parent"}; @@ -102,15 +109,21 @@ TEST_F(ConstantFoldingTest, AddTree) { Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(9, output.node_size()); - - // We expect the following rewrite(s) to occur (for both Add and Mul): + // We expect the following rewrite(s) to occur: + // // + + + // / \ / \ / \ - // 2.0 + --> x + --> x 4.0 - // / \ / \ - // 2.0 x 2.0 2.0 + // 1.0 + --> x + --> x 3.0 + // / \ / \ + // 2.0 x 1.0 2.0 + // + // * * * + // / \ / \ / \ + // 4.0 * --> y * --> y 20.0 + // / \ / \ + // 5.0 y 4.0 5.0 + EXPECT_EQ(11, output.node_size()); for (const auto& node : output.node()) { if (node.name() == "add_child") { EXPECT_EQ("Const", node.op()); @@ -130,26 +143,26 @@ TEST_F(ConstantFoldingTest, AddTree) { } else if (node.name() == "mul_parent") { EXPECT_EQ("Mul", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("y", node.input(0)); EXPECT_EQ("mul_child", node.input(1)); } else if (node.name() == "addmul_child") { // Unchanged. EXPECT_EQ("Add", node.op()); EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("c2", node.input(0)); + EXPECT_EQ("c4", node.input(0)); EXPECT_EQ("x", node.input(1)); } } - // Check that the reciprocals have the expected value. - std::vector fetch = {"c4"}; + // Check that the result nodes have the expected value. + std::vector fetch = {"c3", "c20"}; auto tensor_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(fetch.size(), tensor_expected.size()); fetch = {"add_child", "mul_child"}; auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(fetch.size(), tensors.size()); for (int i = 0; i < fetch.size(); i++) { - test::ExpectTensorEqual(tensor_expected[0], tensors[i]); + test::ExpectTensorEqual(tensor_expected[i], tensors[i]); } } diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 1f68ecbade9147b652ac970aa1c5ec4b056209c7..d2da125236ab4f9b386ba2c6dc808e2b030c819c 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -58,11 +58,7 @@ void PruneControlInputs(NodeDef* node) { int pos = 0; while (pos < node->input_size()) { const string& input = node->input(pos); - // TODO(rmlarsen): Remove control inputs that also appears as a regular - // inputs. Currently, doing so breaks testControlFlowStrictness in - // python/framework/function_test. - // if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) { - if (IsControlInput(input) && !inputs.insert(input).second) { + if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) { VLOG(1) << "**** Removing duplicate control input: " << input << " from node " << node->DebugString(); node->mutable_input()->SwapElements(pos, node->input_size() - 1); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h index 3f6f418bee69cc86d8865bccd266803ade2ef2c1..02d8a0f32a9bbe4e49c484ece601e219257908c0 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ #include #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" @@ -73,4 +73,4 @@ class DependencyOptimizer : public GraphOptimizer { } // end namespace grappler } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index ea7b05d3810f7a4b9f6388e040df930526f6e47e..50e6ba4a6483cf55e32e3d04f1b3af42c48d9f87 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -590,7 +590,7 @@ class NodeProcessor : public GraphProcessor { // to ensure added_node is in the same frame with node_. NodeDef* added_node = graph_->add_node(); *added_node = *input_node; - string base_name = strings::StrCat(node_->name(), "-", input_node->name()); + string base_name = strings::StrCat(node_->name(), "-", input_index); string node_name = LayoutOptimizerNode(base_name); added_node->set_name(node_name); *node_->mutable_input(input_index) = node_name; @@ -1647,12 +1647,32 @@ class StridedSliceProcessor : public SliceProcessor { return errors::InvalidArgument("invalid mask value: ", i); } if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK(); - if (i == 2 || i == 3) i += 2; - if (i == 4 || i == 5) i += 4; - if (i == 6 || i == 7) i += 6; - if (i == 8 || i == 9) i -= 6; - if (i == 10 || i == 11) i -= 4; - if (i == 12 || i == 13) i -= 2; + switch (i) { + case 2: + case 3: + i += 2; + break; + case 4: + case 5: + i += 4; + break; + case 6: + case 7: + i += 6; + break; + case 8: + case 9: + i -= 6; + break; + case 10: + case 11: + i -= 4; + break; + case 12: + case 13: + i -= 2; + break; + } node_->mutable_attr()->at(mask).set_i(i); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index 587642c96e879f62f7ead809e7d01888ef320f93..5cb366df2dccee2260c6f407e992e73296712ccc 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -172,8 +172,7 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) { Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); NodeMap node_map(&output); - string input_name = - strings::StrCat("Conv2DBackpropInput-InputSizes", "-", "LayoutOptimizer"); + string input_name = "Conv2DBackpropInput-0-LayoutOptimizer"; auto input_sizes_node = node_map.GetNode(input_name); CHECK(input_sizes_node); auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput"); @@ -288,7 +287,7 @@ TEST_F(LayoutOptimizerTest, Pad) { auto pad = node_map.GetNode("p"); EXPECT_EQ(pad->input(0), "Conv2D"); - auto pad_const = node_map.GetNode("p-c-LayoutOptimizer"); + auto pad_const = node_map.GetNode("p-1-LayoutOptimizer"); EXPECT_TRUE(pad_const); EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end()); Tensor tensor; @@ -476,9 +475,9 @@ TEST_F(LayoutOptimizerTest, SplitDimC) { Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); NodeMap node_map(&output); auto split_node = node_map.GetNode("split"); - EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer"); + EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer"); EXPECT_EQ(split_node->input(1), "Conv2D"); - auto split_const = node_map.GetNode("split-c-LayoutOptimizer"); + auto split_const = node_map.GetNode("split-0-LayoutOptimizer"); EXPECT_EQ(split_const->op(), "Const"); EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1); } @@ -496,9 +495,9 @@ TEST_F(LayoutOptimizerTest, SplitDimH) { Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); NodeMap node_map(&output); auto split_node = node_map.GetNode("split"); - EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer"); + EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer"); EXPECT_EQ(split_node->input(1), "Conv2D"); - auto split_const = node_map.GetNode("split-c-LayoutOptimizer"); + auto split_const = node_map.GetNode("split-0-LayoutOptimizer"); EXPECT_EQ(split_const->op(), "Const"); EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2); } @@ -516,9 +515,9 @@ TEST_F(LayoutOptimizerTest, SplitDimW) { Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); NodeMap node_map(&output); auto split_node = node_map.GetNode("split"); - EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer"); + EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer"); EXPECT_EQ(split_node->input(1), "Conv2D"); - auto split_const = node_map.GetNode("split-c-LayoutOptimizer"); + auto split_const = node_map.GetNode("split-0-LayoutOptimizer"); EXPECT_EQ(split_const->op(), "Const"); EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3); } @@ -536,9 +535,9 @@ TEST_F(LayoutOptimizerTest, SplitDimN) { Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); NodeMap node_map(&output); auto split_node = node_map.GetNode("split"); - EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer"); + EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer"); EXPECT_EQ(split_node->input(1), "Conv2D"); - auto split_const = node_map.GetNode("split-c-LayoutOptimizer"); + auto split_const = node_map.GetNode("split-0-LayoutOptimizer"); EXPECT_EQ(split_const->op(), "Const"); EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0); } @@ -582,8 +581,8 @@ TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) { EXPECT_EQ(concat_node->input(0), "split:1"); EXPECT_EQ(concat_node->input(1), "split:1"); EXPECT_EQ(concat_node->input(2), "split:1"); - EXPECT_EQ(concat_node->input(3), "concat-axis-LayoutOptimizer"); - auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer"); + EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer"); + auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer"); EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1); } @@ -603,8 +602,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimH) { auto concat_node = node_map.GetNode("concat"); EXPECT_EQ(concat_node->input(0), "split"); EXPECT_EQ(concat_node->input(1), "split:1"); - EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer"); - auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer"); + EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer"); + auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer"); EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2); } @@ -648,8 +647,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimW) { auto concat_node = node_map.GetNode("concat"); EXPECT_EQ(concat_node->input(0), "split"); EXPECT_EQ(concat_node->input(1), "split:1"); - EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer"); - auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer"); + EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer"); + auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer"); EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3); } @@ -669,8 +668,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimN) { auto concat_node = node_map.GetNode("concat"); EXPECT_EQ(concat_node->input(0), "split"); EXPECT_EQ(concat_node->input(1), "split:1"); - EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer"); - auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer"); + EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer"); + auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer"); EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0); } @@ -690,8 +689,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimC) { auto concat_node = node_map.GetNode("concat"); EXPECT_EQ(concat_node->input(0), "split"); EXPECT_EQ(concat_node->input(1), "split:1"); - EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer"); - auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer"); + EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer"); + auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer"); EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1); } @@ -861,10 +860,10 @@ TEST_F(LayoutOptimizerTest, SliceConst) { NodeMap node_map(&output); auto slice_node = node_map.GetNode("slice"); EXPECT_EQ(slice_node->input(0), "Conv2D"); - EXPECT_EQ(slice_node->input(1), "slice-begin-LayoutOptimizer"); - EXPECT_EQ(slice_node->input(2), "slice-size-LayoutOptimizer"); + EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer"); + EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer"); - auto begin_const = node_map.GetNode("slice-begin-LayoutOptimizer"); + auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer"); Tensor begin_tensor; EXPECT_TRUE(begin_tensor.FromProto( begin_const->mutable_attr()->at({"value"}).tensor())); @@ -872,7 +871,7 @@ TEST_F(LayoutOptimizerTest, SliceConst) { test::FillValues(&begin_tensor_expected, {0, 1, 2, 3}); test::ExpectTensorEqual(begin_tensor_expected, begin_tensor); - auto size_const = node_map.GetNode("slice-size-LayoutOptimizer"); + auto size_const = node_map.GetNode("slice-2-LayoutOptimizer"); Tensor size_tensor; EXPECT_TRUE(size_tensor.FromProto( size_const->mutable_attr()->at({"value"}).tensor())); diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 8418abd80f84a7675dd34414dc582fb31089672b..f537ecc41b964fb6c5f2e24891891c9407fcffef 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -829,7 +829,7 @@ static NodeDef* FindSwapOutTrigger( view.GetFanout(generator); NodeDef* trigger = nullptr; Costs::NanoSeconds earliest_fanout( - static_cast(std::numeric_limits::max())); + static_cast(std::numeric_limits::max() >> 2)); for (const auto& port : fanout) { if (port.node == node) { @@ -861,8 +861,9 @@ static bool IsSwappable(GraphView::InputPort input) { return !IsRefType(dtype); } -static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item, - std::unordered_set* skip_list) { +static bool IdentifySwappingCandidates( + Cluster* cluster, GrapplerItem* item, std::unordered_set* skip_list, + std::unordered_map* nodes_to_swap) { GraphMemory memory(*item); const std::unordered_map& devices = cluster->GetDevices(); @@ -960,9 +961,8 @@ static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item, } } if (!found) { - AttrValue& val = - (*fanout_to_swap.node->mutable_attr())["_swap_to_host"]; - val.mutable_list()->add_i(fanout_to_swap.port_id); + (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back( + fanout_to_swap.port_id); required_savings -= live_tensor.memory_used; updated_graph = true; if (required_savings < 0) { @@ -978,14 +978,13 @@ static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item, bool SwappingPass(RewriterConfig::MemOptType optimization_level, Cluster* cluster, GrapplerItem* item, std::unordered_set* skip_list) { - bool updated_graph = false; + std::unordered_map nodes_to_swap; if (optimization_level == RewriterConfig::SWAPPING_HEURISTICS || optimization_level == RewriterConfig::HEURISTICS) { // Use heuristics to figure out what needs to be swapped; - updated_graph = IdentifySwappingCandidates(cluster, item, skip_list); + IdentifySwappingCandidates(cluster, item, skip_list, &nodes_to_swap); } // Look for manual annotatations in the graph. - std::unordered_map nodes_to_swap; for (auto& node : *item->graph.mutable_node()) { if (node.attr().count("_swap_to_host") != 0) { SwapInfo& swap_info = nodes_to_swap[&node]; @@ -1035,10 +1034,11 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, } GraphView view(&item->graph); + bool updated_graph = false; + for (auto& swap : nodes_to_swap) { NodeDef* node = swap.first; const SwapInfo& swap_info = swap.second; - if (skip_list->find(node->name()) != skip_list->end()) { continue; } @@ -1064,7 +1064,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, skip_list->insert(input_name); } - // Make sure the tensor isn't swapped out quickly look for node that + // Make sure the tensor is swapped out quickly: look for node that // will execute just after the tensor is generated and add a control // dependency from the swap out node to that node. NodeDef* out_trigger = diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index 185ac6040c4ce85ca5e7f8eadbe41b05fbe339df..dd2d20d8d682856a8a94f99e4ca2aa706331d9d4 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -280,10 +280,14 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) { Output axis = ops::Const(s.WithOpName("axis"), 0); Output e = ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {a, b, c, d}, axis); + Output f = ops::Square(s.WithOpName("f").WithDevice("/gpu:0"), a); + Output g = ops::Sqrt(s.WithOpName("g").WithDevice("/gpu:0"), b); + Output h = ops::Exp(s.WithOpName("h").WithDevice("/gpu:0"), c); + Output i = ops::Log(s.WithOpName("i").WithDevice("/gpu:0"), d); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"e"}; + item.fetch = {"e", "f", "g", "h", "i"}; std::unique_ptr cluster(CreateVirtualCluster()); @@ -294,14 +298,12 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) { for (const auto& node : output.node()) { if (node.name() == "e") { - EXPECT_TRUE(node.attr().count("_swap_to_host") > 0); - const AttrValue& val = node.attr().at("_swap_to_host"); - EXPECT_TRUE(val.has_list()); - std::set inputs_to_swap; - for (int64 input_id : val.list().i()) { - inputs_to_swap.insert(input_id); - } - EXPECT_EQ(std::set({1, 2, 3}), inputs_to_swap); + EXPECT_EQ(5, node.input_size()); + EXPECT_EQ("a", node.input(0)); + EXPECT_EQ("swap_in_e_1", node.input(1)); + EXPECT_EQ("swap_in_e_2", node.input(2)); + EXPECT_EQ("swap_in_e_3", node.input(3)); + EXPECT_EQ("axis", node.input(4)); } } } @@ -333,9 +335,10 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { TF_EXPECT_OK(status); for (const auto& node : output.node()) { - if (node.name() == "d") { - EXPECT_EQ(1, node.attr().count("_swap_to_host")); - EXPECT_EQ(2, node.attr().at("_swap_to_host").list().i(0)); + if (node.name() == "e") { + // The d node isn't swappable. + EXPECT_EQ(4, node.input_size()); + EXPECT_EQ("d", node.input(2)); } } } diff --git a/tensorflow/core/grappler/optimizers/static_schedule.h b/tensorflow/core/grappler/optimizers/static_schedule.h index aa2726a2bdf95fa6f73d131e36371b8c18de1aaf..678b4d193fb30610820769d5e899322f924da4ad 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule.h +++ b/tensorflow/core/grappler/optimizers/static_schedule.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ #include @@ -47,4 +47,4 @@ Status EstimateRequiredTimes( } // namespace grappler } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ diff --git a/tensorflow/core/grappler/utils/frame.h b/tensorflow/core/grappler/utils/frame.h index be726ae795769609769709746ce7bb74f849e37a..95b72748f4e1f13f1c61d64c4a457287e9d7d46b 100644 --- a/tensorflow/core/grappler/utils/frame.h +++ b/tensorflow/core/grappler/utils/frame.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ #include #include "tensorflow/core/framework/graph.pb.h" @@ -40,4 +40,4 @@ Status IdentifyFramesWithNodeMap(const GraphDef& graph, const NodeMap& node_map, } // namespace grappler } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h index 4e46169971ac5a92b79370c01d4634cf9e6c1b96..4fb7aab6474c35eaa9d3ebbb93f0a70ab16c5fb4 100644 --- a/tensorflow/core/grappler/utils/scc.h +++ b/tensorflow/core/grappler/utils/scc.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ #include #include "tensorflow/core/framework/graph.pb.h" @@ -43,4 +43,4 @@ int IdentifyLoops(const GraphDef& graph, } // namespace grappler } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h index f2c9bbfa4ebce373a4fa80f399ce3d2b59a576f4..7700fe41e40e6d1111c9e84aabfd2a05968ef882 100644 --- a/tensorflow/core/grappler/utils/topological_sort.h +++ b/tensorflow/core/grappler/utils/topological_sort.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -28,4 +28,4 @@ Status TopologicalSort(GraphDef* graph); } // namespace grappler } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f40074f1afb4d854b7d88e4d91d97445f285895f..fd99409c9b35ae0ee2a3cbd9da9067fdc6434a8f 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -369,6 +369,22 @@ cc_library( ], ) +cc_library( + name = "batch_kernels", + srcs = ["batch_kernels.cc"], + deps = [ + "//tensorflow/core:batch_ops_op_lib", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:concat_lib_hdrs", + "//tensorflow/core/kernels:ops_util_hdrs", + "//tensorflow/core/kernels:split_lib_hdrs", + "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", + "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs", + ], + alwayslink = 1, +) + tf_kernel_library( name = "record_input_op", srcs = [ @@ -4268,7 +4284,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/fft2d:fft2d_headers", - "@fft2d//:fft2d", + "@fft2d", ], ) @@ -4975,6 +4991,7 @@ filegroup( "debug_ops.*", "scatter_nd_op*", "critical_section.*", + "batch_kernels.*", ], ), visibility = ["//visibility:public"], @@ -5007,8 +5024,8 @@ cc_library( "//tensorflow/core:protos_all_cc_impl", "//third_party/eigen3", "//third_party/fft2d:fft2d_headers", - "@fft2d//:fft2d", - "@gemmlowp//:gemmlowp", + "@fft2d", + "@gemmlowp", "@protobuf_archive//:protobuf", ], alwayslink = 1, @@ -5079,7 +5096,7 @@ tf_kernel_library( "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:nn_ops_op_lib", "//third_party/eigen3", - "@gemmlowp//:gemmlowp", + "@gemmlowp", ], ) @@ -5840,9 +5857,10 @@ tf_mkl_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:nn_ops_op_lib", + ] + if_mkl([ "//third_party/mkl:intel_binary_blob", "@mkl_dnn//:mkl_dnn", - ], + ]), ) tf_mkl_kernel_library( @@ -6008,6 +6026,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/eigen3", - "@gemmlowp//:gemmlowp", + "@gemmlowp", ], ) diff --git a/tensorflow/core/kernels/adjust_hsv_gpu.cu.h b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h index c160ce2c3349fbd08a1d512e35a424dc00919628..49df5ae296b3e2a213c436d0e4656757c49cb16e 100644 --- a/tensorflow/core/kernels/adjust_hsv_gpu.cu.h +++ b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h @@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ +#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ #if GOOGLE_CUDA @@ -143,4 +143,4 @@ __global__ void adjust_hsv_nhwc(const int64 number_elements, } // namespace tensorflow #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ +#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc similarity index 99% rename from tensorflow/contrib/batching/kernels/batch_kernels.cc rename to tensorflow/core/kernels/batch_kernels.cc index 6041d8c9b2ca14bd325d1e7ea562bc4bc27d6a51..5b4e1a809fa4b9e3d5c5e1b877233b31826bd386 100644 --- a/tensorflow/contrib/batching/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -13,20 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/batching/shared_batch_scheduler.h" -#include "tensorflow/contrib/batching/util/periodic_function.h" + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/periodic_function.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/split_lib.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/macros.h" + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; diff --git a/tensorflow/core/kernels/batch_util.h b/tensorflow/core/kernels/batch_util.h index b066e2a5748e6c2e0a63ef7e27a528be99067b83..0d634ae7b07ee641eb13167d6f9fcb9ed5f0d974 100644 --- a/tensorflow/core/kernels/batch_util.h +++ b/tensorflow/core/kernels/batch_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ +#ifndef TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -35,4 +35,4 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index); } // namespace batch_util } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ +#endif // TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h index ff8ebb349f66df63bb23f4985212240f69efc542..25c5f9cf424fdb286922548ea7ab0a35157a3502 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ - +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ #include #include @@ -657,4 +656,4 @@ size_t ASBSQueue::SchedulingCapacity() const { } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h index 920797210079bf7ba095c4652fe952510664c47d..2b5a991caf2fc3fdb1068070946f29d26c6a55ff 100644 --- a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ #include #include @@ -265,4 +265,4 @@ BasicBatchScheduler::BasicBatchScheduler( } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler.h b/tensorflow/core/kernels/batching_util/batch_scheduler.h index a5316f152b19db2de239ff54dbca0858314d2a25..f6d9a8f0c8824188d83124d857ca9def7224bc99 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/batch_scheduler.h @@ -23,8 +23,8 @@ limitations under the License. // // This file defines an abstract BatchScheduler class. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ #include #include @@ -278,4 +278,4 @@ void Batch::Close() { } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ diff --git a/tensorflow/core/kernels/batching_util/fake_clock_env.h b/tensorflow/core/kernels/batching_util/fake_clock_env.h index b2848afe0741fc0a7d0cacce8d20bbb7ce027295..60f1cbe7bd4d3bb73abfab413cdddaecf5de6c68 100644 --- a/tensorflow/core/kernels/batching_util/fake_clock_env.h +++ b/tensorflow/core/kernels/batching_util/fake_clock_env.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ #include #include @@ -73,4 +73,4 @@ class FakeClockEnv : public EnvWrapper { } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ diff --git a/tensorflow/core/kernels/batching_util/periodic_function.h b/tensorflow/core/kernels/batching_util/periodic_function.h index 6811cd015edfc02da70e979bdc9902b8b310c791..dbf1733dcc399522a673e5724dfeb62446f72a0f 100644 --- a/tensorflow/core/kernels/batching_util/periodic_function.h +++ b/tensorflow/core/kernels/batching_util/periodic_function.h @@ -49,9 +49,8 @@ limitations under the License. // PeriodicFunction periodic_function_; // }; -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ - +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ #include "tensorflow/core/kernels/batching_util/periodic_function.h" @@ -132,4 +131,4 @@ class PeriodicFunction { } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 3736d8ef64d84a37c49814823d0e04db3a21ccfb..b77289aded437b2e6955ced3f7eca2aa5bd182dd 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ #include #include @@ -702,4 +702,4 @@ size_t QueueHandle::SchedulingCapacity() const { } // namespace serving } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/core/kernels/bitcast_op.h b/tensorflow/core/kernels/bitcast_op.h index 0413569e795bcc0911d95a7a946e172579b4ef3a..900ab6f35c15e908a415849784b612da2b6d7c22 100644 --- a/tensorflow/core/kernels/bitcast_op.h +++ b/tensorflow/core/kernels/bitcast_op.h @@ -15,8 +15,8 @@ limitations under the License. // See docs in ../ops/array_ops.cc. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ +#ifndef TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ #include // for memcpy @@ -27,4 +27,4 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/casts.h" -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ +#endif // TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h index cdf191f4c768c2ed3bd15b0ff45fdfa27800653c..2d2d87134e786139386509c6e5f353bb88882915 100644 --- a/tensorflow/core/kernels/captured_function.h +++ b/tensorflow/core/kernels/captured_function.h @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_ +#define TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_ #include "tensorflow/core/kernels/data/captured_function.h" -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_ +#endif // TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_ diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h index 6309e4a4dc6f3ae094e5a310ca237474afeeca14..470e9e08041e808f7459b3c654d55b82fde629a9 100644 --- a/tensorflow/core/kernels/cast_op_impl.h +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ #define EIGEN_USE_THREADS @@ -181,4 +181,4 @@ GetSyclCastFromDouble(DataType dst_dtype); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ +#endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ diff --git a/tensorflow/core/kernels/compare_and_bitpack_op.h b/tensorflow/core/kernels/compare_and_bitpack_op.h index 8e020249c106f28a8aada2cef6c31c6796b6d332..af8566c7ce200004bc6e0b5fe82afb239ad9cfad 100644 --- a/tensorflow/core/kernels/compare_and_bitpack_op.h +++ b/tensorflow/core/kernels/compare_and_bitpack_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_ +#ifndef TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_ +#define TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" @@ -39,4 +39,4 @@ struct CompareAndBitpack { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_ +#endif // TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_ diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h index 27db6ee78533c59f26f538bc59956e50c6111ee7..794ac6fa6de1eb06fcfa614bbfa472814d630d99 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.h +++ b/tensorflow/core/kernels/conditional_accumulator_base.h @@ -161,21 +161,21 @@ class ConditionalAccumulatorBase : public ResourceBase { * The below macros return a boolean if the test fails, so that the calling * function can get an indication that a failure has occurred. */ -#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS) \ - do { \ - if (!TF_PREDICT_TRUE(EXP)) { \ - (CTX)->CtxFailure((STATUS)); \ - return false; \ - } \ +#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) { \ + (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ + return false; \ + } \ } while (0) -#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS) \ - do { \ - ::tensorflow::Status _s(STATUS); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(_s); \ - return false; \ - } \ +#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return false; \ + } \ } while (0) /* diff --git a/tensorflow/core/kernels/cuda_device_array.h b/tensorflow/core/kernels/cuda_device_array.h index a570993cf866a23ff205fb8d79c9db8badf27685..e7a5db0683eba48295dca96c6c7599126e436536 100644 --- a/tensorflow/core/kernels/cuda_device_array.h +++ b/tensorflow/core/kernels/cuda_device_array.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ +#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ #if GOOGLE_CUDA @@ -117,4 +117,4 @@ class CudaDeviceArrayOnHost { #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ +#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ diff --git a/tensorflow/core/kernels/cuda_device_array_gpu.h b/tensorflow/core/kernels/cuda_device_array_gpu.h index 220f7626368852aa8b19ad18285606ed775f80b5..64fa3cb806bc7454bc6d9893e560201a620df43a 100644 --- a/tensorflow/core/kernels/cuda_device_array_gpu.h +++ b/tensorflow/core/kernels/cuda_device_array_gpu.h @@ -15,8 +15,8 @@ limitations under the License. // Contains structs and functions to be included in device code. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ #if GOOGLE_CUDA @@ -47,4 +47,4 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice( #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ +#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h index 3c389a82ab4070d5fb1bf3a091a4c85a6309eda9..ecfa23750c213361bc2d0be8df0091ed6ea26dd9 100644 --- a/tensorflow/core/kernels/cuda_solvers.h +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -427,7 +427,7 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo( int64 size, const string& debug_info) { DeviceLapackInfo new_dev_info(context_, size, debug_info); scratch_tensor_refs_.emplace_back(new_dev_info.tensor()); - return std::move(new_dev_info); + return new_dev_info; } } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc index 5fb0735ac19ba9eb057dd68c7f2d849c65d5edaa..cf86478b0fe43c777563e62e6b3fea9c7d2e6575 100644 --- a/tensorflow/core/kernels/cwise_op_pow.cc +++ b/tensorflow/core/kernels/cwise_op_pow.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER7(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32, - int64, complex64, complex128); +REGISTER5(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, + complex64, complex128); +REGISTER2(BinaryOp, CPU, "Pow", functor::safe_pow, int32, int64); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double, @@ -25,5 +26,5 @@ REGISTER4(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double, #endif #ifdef TENSORFLOW_USE_SYCL REGISTER2(BinaryOp, SYCL, "Pow", functor::pow, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index da70b1e314e2fc1679401920f8a42dd37105e5af..06918075a42648a3cf7135376d728fa466e7c469 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -115,6 +116,35 @@ struct functor_traits> { enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false }; }; +template +struct safe_scalar_binary_pow_op { + static_assert(std::is_integral::value, "Integer type expected"); + static_assert(std::is_integral::value && + std::is_signed::value, + "Signed integer type expected"); + + bool* const error; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) + : error(error) {} + + EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, + const Exponent& b) const { + const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); + if (TF_PREDICT_TRUE(safe_b >= 0)) { + return numext::pow(a, safe_b); + } else { + *error = true; + return 0; + } + } +}; + +template +struct functor_traits> { + enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false }; +}; + template struct safe_div_or_mod_op { static_assert(std::is_integral::value, "Integer type expected"); @@ -741,6 +771,11 @@ struct floor_div_real : base> {}; template struct pow : base> {}; +template +struct safe_pow : base> { + static const bool has_errors = true; +}; + template struct maximum : base> {}; diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc index 693c6467ac592e3357e5b06a620a64b3829bc938..e561e59cf5a23d6d4881c7c5fcf289ccff4c21cb 100644 --- a/tensorflow/core/kernels/cwise_ops_common.cc +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -40,6 +40,11 @@ void BinaryOpShared::SetComputeError(OpKernelContext* ctx) { if ((op == "Div" || op == "Mod" || op == "FloorMod" || op == "FloorDiv") && DataTypeIsInteger(ctx->op_kernel().input_type(0))) { ctx->CtxFailure(errors::InvalidArgument("Integer division by zero")); + } else if ((op == "Pow") && + DataTypeIsInteger(ctx->op_kernel().input_type(0)) && + DataTypeIsSigned(ctx->op_kernel().input_type(1))) { + ctx->CtxFailure(errors::InvalidArgument( + "Integers to negative integer powers are not allowed")); } else { ctx->CtxFailure( errors::Internal("Unexpected error in binary operator " diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index 99e0ef426e04b38027617dcd91f579c082638011..32d2bc3aaebf440584934231a8555199026074ae 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ #include #include @@ -105,4 +105,4 @@ class CapturedFunction { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ diff --git a/tensorflow/core/kernels/data/dataset.h b/tensorflow/core/kernels/data/dataset.h index 3cb3c08a327d00cda565a09851a1faf6d79a4842..2ef31ddfaaa2fd1bd6a4898726d788d1ceece82e 100644 --- a/tensorflow/core/kernels/data/dataset.h +++ b/tensorflow/core/kernels/data/dataset.h @@ -12,18 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ #include #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -193,9 +195,17 @@ class GraphDefBuilderWrapper { return Status::OK(); } + // Returns whether an op has been whitelisted for use inside map_fns. + // Uses a heuristic to whitelist source dataset ops which have been + // marked stateful due to b/65524810. + // Also looks up the `op_def->name` in the global + // `WhitelistedStatefulOpRegistry`. bool IsOpWhitelisted(const OpDef* op_def) const { - return StringPiece(op_def->name()).ends_with("Dataset") && - HasAttr(op_def, "output_shapes"); + return (StringPiece(op_def->name()).ends_with("Dataset") && + op_def->output_arg_size() == 1 && + op_def->output_arg(0).type() == DT_VARIANT) || + dataset::WhitelistedStatefulOpRegistry::Global()->Contains( + op_def->name()); } bool HasAttr(const string& op_type_name, const string& attr_name) const; @@ -596,4 +606,4 @@ Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 40bc8735847f56157d81f6d5fb7a2d02291232fe..6c4191c2be6c55bfde7c5e8bd2e3b1e92edbaf27 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" @@ -32,4 +32,4 @@ Status MakeIteratorFromInputElement( } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 244df137cbdc325da236f25e0d45cf2b37269015..56044a3d41a9f8f2af3c3a72344845e3a59151af 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -829,8 +829,8 @@ class IteratorGetNextOp : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { IteratorResource* iterator; - OP_REQUIRES_OK(ctx, - LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); // The call to `iterator->GetNext()` may block and depend on an // inter-op thread pool thread, so we issue the call from the // owned thread pool. @@ -870,6 +870,39 @@ class IteratorGetNextOp : public AsyncOpKernel { std::unique_ptr thread_pool_; }; +class IteratorGetNextSyncOp : public OpKernel { + public: + explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + IteratorResource* iterator; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); + core::ScopedUnref unref_iterator(iterator); + + std::vector components; + bool end_of_sequence = false; + + IteratorContext::Params params; + params.env = ctx->env(); + params.stats_aggregator_getter = [iterator]() { + return iterator->stats_aggregator(); + }; + params.runner = *(ctx->runner()); + params.function_library = iterator->function_library(); + IteratorContext iter_ctx(std::move(params)); + + OP_REQUIRES_OK(ctx, + iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); + OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); + + for (int i = 0; i < components.size(); ++i) { + // TODO(mrry): Check that the shapes match the shape attrs. + ctx->set_output(i, components[i]); + } + } +}; + class IteratorToStringHandleOp : public OpKernel { public: explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) @@ -1033,6 +1066,8 @@ REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU), OneShotIteratorOp); REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU), IteratorGetNextOp); +REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU), + IteratorGetNextSyncOp); REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU), IteratorToStringHandleOp); REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 01f9b9fa09621562fae38a7e8b6c7957a8e5538e..89360d1cd95e896ebf284a0058edb122c7f82d09 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -95,10 +95,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); - DataTypeVector other_arguments_types( - captured_func_->captured_inputs().size()); - std::vector other_arguments( - captured_func_->captured_inputs().size()); + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + std::vector other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index f09871d98d3eac325b91b52c7f7b6d4e18e6012e..bc4426a9fdbab971a4e49d57ffcea6896fc037a7 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -109,10 +109,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); // Input: other_arguments - DataTypeVector other_arguments_types( - captured_func_->captured_inputs().size()); - std::vector other_arguments( - captured_func_->captured_inputs().size()); + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + std::vector other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h index 0d0c38eb58314962554b929d1a5c4a387ab68e55..a34691b5a2f43034feaf55241d0a445456c23bc3 100644 --- a/tensorflow/core/kernels/data/sql/driver_manager.h +++ b/tensorflow/core/kernels/data/sql/driver_manager.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ #include "tensorflow/core/kernels/data/sql/query_connection.h" @@ -38,4 +38,4 @@ class DriverManager { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h index 194714897221f73ffec51c50c5202860b1bd0b46..f31017bd1981c3809d9b7daaa2dc56256d19d914 100644 --- a/tensorflow/core/kernels/data/sql/query_connection.h +++ b/tensorflow/core/kernels/data/sql/query_connection.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ #include "tensorflow/core/framework/tensor.h" @@ -64,4 +64,4 @@ class QueryConnection { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h index b36b69eae4e5ba6fc65e4075703be8ad5720c8b4..787c17d6c00d99afad3d7814c3c2daaf4295b1b3 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ #include @@ -53,4 +53,4 @@ class SqliteQueryConnection : public QueryConnection { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/stats_aggregator.h b/tensorflow/core/kernels/data/stats_aggregator.h index 4cb8dba5cbb4a3866b94101df0f1e9a8e52d9cf2..076a56b0bf100161fe2cf4384e6be0809eb251fe 100644 --- a/tensorflow/core/kernels/data/stats_aggregator.h +++ b/tensorflow/core/kernels/data/stats_aggregator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ #include #include @@ -81,4 +81,4 @@ class StatsAggregatorResource : public ResourceBase { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h index 25396bd3e72f01eb40922a83e6dd18d1fc81e077..97c31668acba8869f1f5947acbbb4069c4adccb0 100644 --- a/tensorflow/core/kernels/data/window_dataset.h +++ b/tensorflow/core/kernels/data/window_dataset.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ #include @@ -45,4 +45,4 @@ Status NewWindowDataset(std::vector> elements, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index 2aa6dbe6f3e1602e0fb94b8b196d41e29d644fd8..69ab78d6355dc2e22c7d77b62123fc0bd2359fc4 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DATASET_H_ +#define TENSORFLOW_CORE_KERNELS_DATASET_H_ #include "tensorflow/core/kernels/data/dataset.h" -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_ +#endif // TENSORFLOW_CORE_KERNELS_DATASET_H_ diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc index ceb152c3f00fe8923429eaa5f8cff026254803a5..44dcbf834ce838e3b25957f88bfcded645104957 100644 --- a/tensorflow/core/kernels/decode_image_op.cc +++ b/tensorflow/core/kernels/decode_image_op.cc @@ -87,11 +87,10 @@ class DecodeImageOp : public OpKernel { channels_ = 3; } else { OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_)); - OP_REQUIRES( - context, - channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4, - errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ", - channels_)); + OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3 || + channels_ == 4, + errors::InvalidArgument( + "channels must be 0, 1, 3, or 4, got ", channels_)); } flags_.components = channels_; @@ -115,9 +114,8 @@ class DecodeImageOp : public OpKernel { if (format_ == kJpgFormat) { OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio)); - OP_REQUIRES(context, - flags_.ratio == 1 || flags_.ratio == 2 || flags_.ratio == 4 || - flags_.ratio == 8, + OP_REQUIRES(context, flags_.ratio == 1 || flags_.ratio == 2 || + flags_.ratio == 4 || flags_.ratio == 8, errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ", flags_.ratio)); OP_REQUIRES_OK(context, context->GetAttr("fancy_upscaling", @@ -132,9 +130,8 @@ class DecodeImageOp : public OpKernel { string dct_method; OP_REQUIRES_OK(context, context->GetAttr("dct_method", &dct_method)); OP_REQUIRES( - context, - (dct_method.empty() || dct_method == "INTEGER_FAST" || - dct_method == "INTEGER_ACCURATE"), + context, (dct_method.empty() || dct_method == "INTEGER_FAST" || + dct_method == "INTEGER_ACCURATE"), errors::InvalidArgument("dct_method must be one of " "{'', 'INTEGER_FAST', 'INTEGER_ACCURATE'}")); if (dct_method == "INTEGER_FAST") { @@ -160,9 +157,9 @@ class DecodeImageOp : public OpKernel { errors::InvalidArgument("Expected image (JPEG, PNG, or GIF), got ", FileFormatString(magic, input))); OP_REQUIRES(context, input.size() <= std::numeric_limits::max(), - errors::InvalidArgument( - FileFormatString(magic, input), - " contents are too large for int: ", input.size())); + errors::InvalidArgument(FileFormatString(magic, input), + " contents are too large for int: ", + input.size())); OP_REQUIRES(context, magic == kPngFormat || channel_bits_ == 8, errors::InvalidArgument(FileFormatString(magic, input), " does not support uint16 output")); @@ -215,10 +212,9 @@ class DecodeImageOp : public OpKernel { input.data(), input.size(), flags, nullptr /* nwarn */, [=, &output](int width, int height, int channels) -> uint8* { Status status(context->allocate_output( - 0, - format_ == kGifFormat - ? TensorShape({1, height, width, channels}) - : TensorShape({height, width, channels}), + 0, format_ == kGifFormat + ? TensorShape({1, height, width, channels}) + : TensorShape({height, width, channels}), &output)); if (!status.ok()) { VLOG(1) << status; @@ -294,6 +290,7 @@ class DecodeImageOp : public OpKernel { // Decode GIF, allocating tensor once the size is known. Tensor* output = nullptr; + string error_string; OP_REQUIRES( context, gif::Decode(input.data(), input.size(), @@ -320,8 +317,10 @@ class DecodeImageOp : public OpKernel { return nullptr; } return output->flat().data(); - }), - errors::InvalidArgument("Invalid GIF data, size ", input.size())); + }, + &error_string), + errors::InvalidArgument("Invalid GIF data (size ", input.size(), "), ", + error_string)); } private: diff --git a/tensorflow/core/kernels/deep_conv2d.h b/tensorflow/core/kernels/deep_conv2d.h index c3f6f66dc9ba6fcf3e29c139eec0030cc7a0be57..17a0230516e27a7121fd632479b9eb8227f83283 100644 --- a/tensorflow/core/kernels/deep_conv2d.h +++ b/tensorflow/core/kernels/deep_conv2d.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ +#define TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ #include "tensorflow/core/framework/types.h" @@ -114,4 +114,4 @@ struct DeepConv2D { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ +#endif // TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h index 097a9f5bfad4f1cf0232b0bb31cf6f88fdb5696b..ba262d56eef62eed3abf23b34da2a3c4727795d4 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.h +++ b/tensorflow/core/kernels/depthwise_conv_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/types.h" @@ -284,4 +284,4 @@ struct DepthwiseInputCopyOp { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ +#endif // TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ diff --git a/tensorflow/core/kernels/determinant_op.h b/tensorflow/core/kernels/determinant_op.h index e931e328e4bbb2e29f3f3ff4fbaf3dfb76fb1ea7..eefdfe0ae40bca1713f9667bf9fced934a412acb 100644 --- a/tensorflow/core/kernels/determinant_op.h +++ b/tensorflow/core/kernels/determinant_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ +#ifndef TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ #include "tensorflow/core/framework/tensor_types.h" @@ -44,4 +44,4 @@ struct LogDeterminantFromPivotedLUFunctor { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ +#endif // TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ diff --git a/tensorflow/core/kernels/eigen_activations.h b/tensorflow/core/kernels/eigen_activations.h index 57c8157b878f6b46ca5a57857747e899fddbebb2..99b4b2abe66d9f372f99af1ef6164774e7ebfabc 100644 --- a/tensorflow/core/kernels/eigen_activations.h +++ b/tensorflow/core/kernels/eigen_activations.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -122,4 +122,4 @@ struct functor_traits > { } // end namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ diff --git a/tensorflow/core/kernels/eigen_attention.h b/tensorflow/core/kernels/eigen_attention.h index f4c42372b1840e0c46b57b133745670a07a8c46c..3a94b8c9933ddbf262552044c73206e1deb9828d 100644 --- a/tensorflow/core/kernels/eigen_attention.h +++ b/tensorflow/core/kernels/eigen_attention.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -239,4 +239,4 @@ ExtractGlimpses(const Input& input, } // end namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h index a44e7197a9412926fc30eecbc8128fe08829d21e..e13e548f863bcdcb5e8853ea19532e8e787e4571 100644 --- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/eigen_volume_patch.h" @@ -617,4 +617,4 @@ CuboidConvolutionBackwardKernel( } // end namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index d172de8e18d89b4e006c0093b603b7d3f305494f..aec76978102ed4d5e8d0cca18f1ae4422acc1515 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h index 2dca664a86d6715e8e9d90842058d6ecc89f569a..62e9f9123dd4101d0e8466fb2f4f90fcb6da73c2 100644 --- a/tensorflow/core/kernels/eigen_cuboid_convolution.h +++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/eigen_volume_patch.h" @@ -224,4 +224,4 @@ CuboidConvolution(const Input& input, const Kernel& kernel, } // end namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h index 94100d71ec30b07e47fafb826d5e428c2bde7bcb..972036833fff6753031e97216d524a014bb81cbb 100644 --- a/tensorflow/core/kernels/eigen_pooling.h +++ b/tensorflow/core/kernels/eigen_pooling.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/eigen_volume_patch.h" @@ -610,4 +610,4 @@ CuboidAvgPooling(const Input& input, DenseIndex patchPlanes, } // end namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ diff --git a/tensorflow/core/kernels/eigen_softmax.h b/tensorflow/core/kernels/eigen_softmax.h index 20bb8a44dd9041b6c704447d7f14979bf0da0efb..a2930a726f908ac4862a47104e379e6d30e88477 100644 --- a/tensorflow/core/kernels/eigen_softmax.h +++ b/tensorflow/core/kernels/eigen_softmax.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -87,4 +87,4 @@ SoftMax(const Input& input, const float beta) } // end namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_ diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h index 7702f3e70a806a3edda48a8a86e3a65571e8ba7e..2fe64cd72ac06e86cccea31145079451d0b28f88 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -1069,4 +1069,4 @@ EIGEN_DEVICE_FUNC } // end namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h index afd5f37e352a5b5e4c2f77666bc3b18be914b1b2..a3d795813de19c9571ffeec705a6e4cb19f6b641 100644 --- a/tensorflow/core/kernels/eigen_volume_patch.h +++ b/tensorflow/core/kernels/eigen_volume_patch.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -653,4 +653,4 @@ OVERRIDE_EVALUATOR(Eigen::DefaultDevice); }; // namespace Eigen -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ diff --git a/tensorflow/core/kernels/eye_functor.h b/tensorflow/core/kernels/eye_functor.h index 70f093f81366e017f3a07614e319435e1bf5aca2..3799cfba9aea54a603af56c5ade9197f53f96dd1 100644 --- a/tensorflow/core/kernels/eye_functor.h +++ b/tensorflow/core/kernels/eye_functor.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_ +#ifndef TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_ #include "tensorflow/core/framework/tensor_types.h" @@ -29,4 +29,4 @@ struct EyeFunctor { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_ +#endif // TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_ diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h index 7aaad6e6c7a48617d1a6cbc679eebc2297828f75..81189866c34819306231edc2073fbdc23fbb9baf 100644 --- a/tensorflow/core/kernels/fake_quant_ops_functor.h +++ b/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ +#ifndef TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ #include @@ -277,4 +277,4 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ +#endif // TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.h b/tensorflow/core/kernels/gather_functor_gpu.cu.h index a50b51b54b1d8e23b4082ba7b6bee8db2cc28382..11ea63d730aa69509edaacf127e62b4bbeb5740f 100644 --- a/tensorflow/core/kernels/gather_functor_gpu.cu.h +++ b/tensorflow/core/kernels/gather_functor_gpu.cu.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ +#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ #if GOOGLE_CUDA @@ -118,4 +118,4 @@ struct GatherFunctor { #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ +#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index 366877bcf5f57139a5600c4e198a7862d8ed9ef7..ffc733e6bb6b45ab463f319de39dfd175e83e5c1 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ #if GOOGLE_CUDA @@ -162,4 +162,4 @@ class AutoTuneSingleton { #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ +#endif // TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index 125d1fd200719de195da2ac3339576decde1ba46..a360d188cc2246b87af348db9958152418742822 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_ +#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_ +#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_ #include #include @@ -225,4 +225,4 @@ class GraphTransferer { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H +#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h index 8eb3995fc4f7974e382eb1370e05bec4a2f4a3f2..dca1f94a9b156bc9199064a72efc69b34956e59f 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ +#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ +#define TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ #include #include @@ -88,4 +88,4 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ +#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h index 993a5f9a3a81d1bfc00b59ec1364209d11ceeaa7..b9328c8e0e891cf637d467e7fcbbac331d84e12c 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h +++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_ +#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_ #include @@ -55,4 +55,4 @@ class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H +#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H diff --git a/tensorflow/core/kernels/i_remote_fused_graph_executor.h b/tensorflow/core/kernels/i_remote_fused_graph_executor.h index 05b76172b203673917f65f048f8132c2fb0de173..eb6b64da583f0ea9e4bb462925ebdf1bcf8dc1e3 100644 --- a/tensorflow/core/kernels/i_remote_fused_graph_executor.h +++ b/tensorflow/core/kernels/i_remote_fused_graph_executor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ +#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ +#define TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -72,4 +72,4 @@ class IRemoteFusedGraphExecutor { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ +#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ diff --git a/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h b/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h index 7d3329f490713c243cb23d2e3232d6e343c55187..9e51c9f51f4c75a7ccd635a0261f633b675326bf 100644 --- a/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h +++ b/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_ +#define TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_ #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/macros.h" @@ -43,4 +43,4 @@ class IRemoteFusedGraphOpsDefinitions { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_ +#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_ diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index 5e405f16a4d141f344532fec8342eea754a80f5e..baf0a4abe48ea0c5a5fed5d7ef3e53925e393b10 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -87,6 +87,14 @@ REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE); REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName); +Status TensorListShape(const TensorList& t, TensorShape* s) { + *s = TensorShape({}); + return Status::OK(); +} + +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName, + TensorListShape); + bool TensorList::Decode(const VariantTensorData& data) { tensors = data.tensors(); string metadata; @@ -251,6 +259,45 @@ REGISTER_KERNEL_BUILDER( #endif // GOOGLE_CUDA +class TensorListElementShape : public OpKernel { + public: + explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + OP_REQUIRES( + c, c->input(0).shape().num_elements() == 1, + errors::InvalidArgument("List tensors are supposed to be scalars.")); + const TensorList* l = c->input(0).scalar()().get(); + OP_REQUIRES(c, l != nullptr, + errors::InvalidArgument( + "TensorListElementShape received a variant which is not a " + "list. Saw: '", + c->input(0).scalar()().DebugString(), "'")); + Tensor* result; + OP_REQUIRES_OK(c, c->allocate_output( + 0, TensorShape{l->element_shape.dims()}, &result)); + for (int i = 0; i < l->element_shape.dims(); ++i) { + if (result->dtype() == DT_INT32) { + result->flat()(i) = l->element_shape.dim_size(i); + } else { + result->flat()(i) = l->element_shape.dim_size(i); + } + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU), + TensorListElementShape); + +#if GOOGLE_CUDA + +REGISTER_KERNEL_BUILDER(Name("TensorListElementShape") + .Device(DEVICE_GPU) + .HostMemory("element_shape"), + TensorListElementShape); + +#endif // GOOGLE_CUDA + class TensorListPopBack : public OpKernel { public: explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) { @@ -299,6 +346,134 @@ REGISTER_KERNEL_BUILDER(Name("TensorListPopBack").Device(DEVICE_GPU), #endif // GOOGLE_CUDA +class TensorListReserve : public OpKernel { + public: + explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + PartialTensorShape element_shape; + OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape)); + int32 num_elements = c->input(1).scalar()(); + TensorList output; + output.element_shape = element_shape; + output.element_dtype = element_dtype_; + output.tensors.resize(num_elements, Tensor(DT_INVALID)); + Tensor* result; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); + result->scalar()() = std::move(output); + } + + private: + DataType element_dtype_; +}; + +REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU), + TensorListReserve); + +#if GOOGLE_CUDA + +REGISTER_KERNEL_BUILDER(Name("TensorListReserve") + .Device(DEVICE_GPU) + .HostMemory("element_shape") + .HostMemory("num_elements"), + TensorListReserve); + +#endif // GOOGLE_CUDA + +class TensorListGetItem : public OpKernel { + public: + explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + OP_REQUIRES( + c, c->input(0).shape().num_elements() == 1, + errors::InvalidArgument("List tensors are supposed to be scalars.")); + const TensorList* l = c->input(0).scalar()().get(); + OP_REQUIRES(c, l != nullptr, + errors::InvalidArgument( + "Input handle is not a list. Saw: '", + c->input(0).scalar()().DebugString(), "'")); + OP_REQUIRES(c, element_dtype_ == l->element_dtype, + errors::InvalidArgument("Invalid data types; op elements ", + DataTypeString(element_dtype_), + " but list elements ", + DataTypeString(l->element_dtype))); + int32 index = c->input(1).scalar()(); + OP_REQUIRES(c, index < l->tensors.size(), + errors::InvalidArgument("Trying to access element ", index, + " in a list with ", l->tensors.size(), + " elements.")); + c->set_output(0, l->tensors[index]); + } + + private: + DataType element_dtype_; +}; + +REGISTER_KERNEL_BUILDER(Name("TensorListGetItem").Device(DEVICE_CPU), + TensorListGetItem); + +#if GOOGLE_CUDA + +REGISTER_KERNEL_BUILDER( + Name("TensorListGetItem").Device(DEVICE_GPU).HostMemory("index"), + TensorListGetItem); + +#endif // GOOGLE_CUDA + +class TensorListSetItem : public OpKernel { + public: + explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + const TensorList* l = c->input(0).scalar()().get(); + OP_REQUIRES(c, l != nullptr, + errors::InvalidArgument( + "Input handle is not a list. Saw: '", + c->input(0).scalar()().DebugString(), "'")); + OP_REQUIRES(c, element_dtype_ == l->element_dtype, + errors::InvalidArgument("Invalid data types; op elements ", + DataTypeString(element_dtype_), + " but list elements ", + DataTypeString(l->element_dtype))); + int32 index = c->input(1).scalar()(); + OP_REQUIRES(c, index < l->tensors.size(), + errors::InvalidArgument("Trying to modify element ", index, + " in a list with ", l->tensors.size(), + " elements.")); + TensorList output; + output = *l; + output.tensors[index] = c->input(2); + Tensor* result; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); + result->scalar()() = std::move(output); + } + + private: + DataType element_dtype_; +}; + +REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU), + TensorListSetItem); + +#if GOOGLE_CUDA + +REGISTER_KERNEL_BUILDER( + Name("TensorListSetItem").Device(DEVICE_GPU).HostMemory("index"), + TensorListSetItem); + +#endif // GOOGLE_CUDA + #define REGISTER_TENSOR_LIST_STACK_CPU(T) \ REGISTER_KERNEL_BUILDER(Name("TensorListStack") \ .TypeConstraint("element_dtype") \ diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 6a2a572b6d7476cb4d457d19c2264c7e8217b7cb..9733883001d4ce7888b4893ecb43047b621a3eba 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ #define EIGEN_USE_THREADS #if GOOGLE_CUDA @@ -76,14 +76,14 @@ class TensorListStack : public OpKernel { errors::InvalidArgument( "Input handle is not a list. Saw: '", c->input(0).scalar()().DebugString(), "'")); - OP_REQUIRES(c, l->element_shape.IsFullyDefined(), - errors::InvalidArgument("Tried to stack elements from a list " - "with non-fully-defined shape.")); OP_REQUIRES(c, element_dtype_ == l->element_dtype, errors::InvalidArgument("Invalid data types; op elements ", DataTypeString(element_dtype_), " but list elements ", DataTypeString(l->element_dtype))); + OP_REQUIRES(c, l->element_shape.IsFullyDefined(), + errors::InvalidArgument("Tried to stack elements from a list " + "with non-fully-defined shape.")); if (num_elements_ != -1) { OP_REQUIRES(c, l->tensors.size() == num_elements_, errors::InvalidArgument("Operation expected a list with ", @@ -98,16 +98,23 @@ class TensorListStack : public OpKernel { } Tensor* output; OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output)); + if (output->NumElements() == 0) { + return; + } ConstMatrixVector inputs_flat; inputs_flat.reserve(l->tensors.size()); for (const auto& t : l->tensors) { + OP_REQUIRES( + c, l->element_shape.IsCompatibleWith(t.shape()), + errors::InvalidArgument( + "Tensor with invalid shape in list. List element shape shape: ", + l->element_shape.DebugString(), + " and tensor shape: ", t.shape().DebugString())); inputs_flat.emplace_back(new typename TTypes::ConstMatrix( t.shaped({1, t.NumElements()}))); } - auto output_flat = - output->shaped({1, static_cast(l->tensors.size()) * - l->element_shape.num_elements()}); + auto output_flat = output->shaped({1, output->NumElements()}); #if GOOGLE_CUDA if (std::is_same::value) { @@ -195,17 +202,26 @@ Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a, for (int i = 0; i < a.tensors.size(); ++i) { const Tensor& a_tensor = a.tensors[i]; const Tensor& b_tensor = b.tensors[i]; + if (a_tensor.dtype() == DT_INVALID) { + out->tensors.push_back(b_tensor); + continue; + } + if (b_tensor.dtype() == DT_INVALID) { + out->tensors.push_back(a_tensor); + continue; + } if (a_tensor.shape() != b_tensor.shape()) { // TODO(apassos) support broadcasting additions here? return errors::InvalidArgument( "Trying to add two tensors with incompatible element shapes. " "One is ", a_tensor.shape().DebugString(), " and the other is ", - b_tensor.shape().DebugString()); + b_tensor.shape().DebugString(), " in position ", i); } Tensor out_tensor; TF_RETURN_IF_ERROR( c->allocate_temp(a_tensor.dtype(), a_tensor.shape(), &out_tensor)); + out->tensors.push_back(out_tensor); switch (out_tensor.dtype()) { #define DTYPE_CASE(dtype) \ case DataTypeToEnum::value: \ @@ -254,4 +270,4 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ +#endif // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ diff --git a/tensorflow/core/kernels/meta_support.h b/tensorflow/core/kernels/meta_support.h index 53aece78e87c17cac76866a84c930f3024d38cae..97f39eb598367b83d4e74d2b0cafadec62bb4cea 100644 --- a/tensorflow/core/kernels/meta_support.h +++ b/tensorflow/core/kernels/meta_support.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_ +#ifndef TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_ +#define TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_ #include "meta/multi_thread_gemm.h" #include "meta/multi_thread_transform.h" @@ -109,4 +109,4 @@ void Clamp(OpKernelContext* context, const quint8* input, int input_count, } // namespace meta } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_ +#endif // TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_ diff --git a/tensorflow/core/kernels/mfcc.h b/tensorflow/core/kernels/mfcc.h index 0d5d9fb90f8bd137aea5d7f3b8c08dfcd1495c18..8268f4720348bbc820bd3f8863698d34999abb7b 100644 --- a/tensorflow/core/kernels/mfcc.h +++ b/tensorflow/core/kernels/mfcc.h @@ -15,8 +15,8 @@ limitations under the License. // Basic class for computing MFCCs from spectrogram slices. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_ +#ifndef TENSORFLOW_CORE_KERNELS_MFCC_H_ +#define TENSORFLOW_CORE_KERNELS_MFCC_H_ #include @@ -74,4 +74,4 @@ class Mfcc { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_ +#endif // TENSORFLOW_CORE_KERNELS_MFCC_H_ diff --git a/tensorflow/core/kernels/mfcc_dct.h b/tensorflow/core/kernels/mfcc_dct.h index 4fa3c01628d7f4888e6dd2c9cb5a1ef664e42723..888b8e8df8c45067981ef7ea27ddf568035dd3ae 100644 --- a/tensorflow/core/kernels/mfcc_dct.h +++ b/tensorflow/core/kernels/mfcc_dct.h @@ -15,8 +15,8 @@ limitations under the License. // Basic minimal DCT class for MFCC speech processing. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ +#ifndef TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ +#define TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ #include @@ -41,4 +41,4 @@ class MfccDct { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ +#endif // TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.h b/tensorflow/core/kernels/mfcc_mel_filterbank.h index a766a20cbca4a7772a62a2701334c87a5ed57531..1bdc2dc93b80a2691d4adec219426b142ef24321 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank.h +++ b/tensorflow/core/kernels/mfcc_mel_filterbank.h @@ -15,8 +15,8 @@ limitations under the License. // Basic class for applying a mel-scale mapping to a power spectrum. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ +#ifndef TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ +#define TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ #include #include "tensorflow/core/framework/op_kernel.h" @@ -63,4 +63,4 @@ class MfccMelFilterbank { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ +#endif // TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ diff --git a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h index bb22b2aa918dad379b80931ba0893feb9366489b..6716a26fac2c77ee1ee5306cc26cf802585dcfc4 100644 --- a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h +++ b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_ +#ifndef TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_ +#define TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_ #define EIGEN_USE_THREADS @@ -41,4 +41,4 @@ TF_CALL_NUMBER_TYPES(DEFINE_CPU_SPECS); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_ +#endif // TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_ diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc index 44b94be3a05662db2d3c190f5955d13a45a6d299..89d37d2f874c0b8fa7550b1c49c0e3c4106e2ee5 100644 --- a/tensorflow/core/kernels/mkl_aggregate_ops.cc +++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc @@ -61,6 +61,16 @@ class MklAddNOp : public OpKernel { GetMklShape(ctx, src2_idx, &(mkl_context.input2_shape)); bool input2_in_mkl_format = mkl_context.input2_shape.IsMklTensor(); + // if the shapes of two tensors are not same raise op error + TensorShape src1_shape, src2_shape; + src1_shape = input0.shape(); + src2_shape = input1.shape(); + if (!src1_shape.IsSameSize(src2_shape)) { + ctx->SetStatus(errors::InvalidArgument( + "Inputs to operation ", this->name(), " of type ", + this->type_string(), " must have the same size and shape. Input 0: ", + src1_shape.DebugString(), " != input 1: ", src2_shape.DebugString())); + } // handle the case of a scalar if (!input1_in_mkl_format && input0.dims() == 0) { const TensorShape& o_shape = input0.shape(); @@ -70,17 +80,16 @@ class MklAddNOp : public OpKernel { mkl_context.output_shape); float user_i1 = (input0.scalar()()); float user_i2 = (input1.scalar()()); - out_tensor->scalar()() = - std::plus{}(user_i1, user_i2); + out_tensor->scalar()() = std::plus{}(user_i1, user_i2); return; } mkl_context.in_dims = input1_in_mkl_format - ? mkl_context.input1_shape.GetDimension() - : input0.dims(); + ? mkl_context.input1_shape.GetDimension() + : input0.dims(); mkl_context.in_dims = input2_in_mkl_format - ? mkl_context.input2_shape.GetDimension() - : input1.dims(); + ? mkl_context.input2_shape.GetDimension() + : input1.dims(); // If there is nothing to compute, return. if (!input1_in_mkl_format && !input2_in_mkl_format) { @@ -89,7 +98,7 @@ class MklAddNOp : public OpKernel { Tensor* out_tensor = nullptr; mkl_context.output_shape.SetMklTensor(false); AllocateOutputSetMklShape(ctx, src1_idx, &out_tensor, o_shape, - mkl_context.output_shape); + mkl_context.output_shape); return; } } @@ -98,9 +107,9 @@ class MklAddNOp : public OpKernel { mkl_context.in_strides = new size_t[mkl_context.in_dims]; // Generate size, stride for input if input is in MKL format. if (input1_in_mkl_format || input2_in_mkl_format) { - const MklShape* tmp_mkl_shape = - (input1_in_mkl_format) ? &mkl_context.input1_shape : - &mkl_context.input2_shape; + const MklShape* tmp_mkl_shape = (input1_in_mkl_format) + ? &mkl_context.input1_shape + : &mkl_context.input2_shape; for (int i = 0; i < mkl_context.in_dims; i++) { mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i]; mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i]; @@ -124,32 +133,33 @@ class MklAddNOp : public OpKernel { Tensor mkl_tmp_input1_buf_tensor, mkl_tmp_input2_buf_tensor; mkl_context.MklPrepareAddNInputs(ctx, &mkl_tmp_input1_buf_tensor, - &mkl_tmp_input2_buf_tensor); + &mkl_tmp_input2_buf_tensor); Tensor* output = nullptr; if (input1_in_mkl_format || input2_in_mkl_format) { - TensorShape tf_shape; - mkl_context.output_shape.SetMklTensor(true); - mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise, dnnResourceDst); - - mkl_context.output_shape.SetTfLayout( - mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides); - if (input1_in_mkl_format == true) { - mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims, - mkl_context.input1_shape.GetTfToMklDimMap()); - } else { - mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims, - mkl_context.input2_shape.GetTfToMklDimMap()); - } - tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast( - mkl_context.output_shape.GetMklLayout())) / - sizeof(T)); - - AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape, - mkl_context.output_shape); + TensorShape tf_shape; + mkl_context.output_shape.SetMklTensor(true); + mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise, + dnnResourceDst); + + mkl_context.output_shape.SetTfLayout( + mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides); + if (input1_in_mkl_format == true) { + mkl_context.output_shape.SetTfDimOrder( + mkl_context.in_dims, mkl_context.input1_shape.GetTfToMklDimMap()); + } else { + mkl_context.output_shape.SetTfDimOrder( + mkl_context.in_dims, mkl_context.input2_shape.GetTfToMklDimMap()); + } + tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast( + mkl_context.output_shape.GetMklLayout())) / + sizeof(T)); + + AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape, + mkl_context.output_shape); } else { - const TensorShape& o_shape = input1.shape(); - mkl_context.output_shape.SetMklTensor(false); - AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape, + const TensorShape& o_shape = input1.shape(); + mkl_context.output_shape.SetMklTensor(false); + AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape, mkl_context.output_shape); } @@ -177,18 +187,16 @@ class MklAddNOp : public OpKernel { void MklCreateInputLayouts(OpKernelContext* context) { bool input1_in_mkl_format = input1_shape.IsMklTensor(); if (!input1_in_mkl_format) { - CHECK_EQ( - dnnLayoutCreate_F32(<_input1, in_dims, in_sizes, in_strides), - E_SUCCESS); + CHECK_EQ(dnnLayoutCreate_F32(<_input1, in_dims, in_sizes, in_strides), + E_SUCCESS); } else { lt_input1 = static_cast(input1_shape.GetCurLayout()); } bool input2_in_mkl_format = input2_shape.IsMklTensor(); if (!input2_in_mkl_format) { - CHECK_EQ( - dnnLayoutCreate_F32(<_input2, in_dims, in_sizes, in_strides), - E_SUCCESS); + CHECK_EQ(dnnLayoutCreate_F32(<_input2, in_dims, in_sizes, in_strides), + E_SUCCESS); } else { lt_input2 = static_cast(input2_shape.GetCurLayout()); } @@ -264,14 +272,14 @@ class MklAddNOp : public OpKernel { bool input2_in_mkl_format = input2_shape.IsMklTensor(); dnnDelete_F32(Eltwise); if (!input1_in_mkl_format || !input2_in_mkl_format) { - delete [] in_sizes; - delete [] in_strides; + delete[] in_sizes; + delete[] in_strides; } if (!input1_in_mkl_format) { - dnnLayoutDelete_F32(lt_input1); + dnnLayoutDelete_F32(lt_input1); } if (!input2_in_mkl_format) { - dnnLayoutDelete_F32(lt_input2); + dnnLayoutDelete_F32(lt_input2); } } } MklAddNOpContext; @@ -303,33 +311,44 @@ class MklAddNOp : public OpKernel { GetMklShape(ctx, src2_idx, &src2_mkl_shape); bool input1_in_mkl_format = src1_mkl_shape.IsMklTensor(); bool input2_in_mkl_format = src2_mkl_shape.IsMklTensor(); - int src1_dims_size = input1_in_mkl_format? - src1_mkl_shape.GetDimension(): src1_tensor.dims(); - int src2_dims_size = input2_in_mkl_format? - src2_mkl_shape.GetDimension(): src2_tensor.dims(); + int src1_dims_size = input1_in_mkl_format ? src1_mkl_shape.GetDimension() + : src1_tensor.dims(); + int src2_dims_size = input2_in_mkl_format ? src2_mkl_shape.GetDimension() + : src2_tensor.dims(); + // if the shapes of two tensors are not same raise op error + TensorShape src1_shape, src2_shape; + src1_shape = src1_tensor.shape(); + src2_shape = src2_tensor.shape(); + if (!src1_shape.IsSameSize(src2_shape)) { + ctx->SetStatus(errors::InvalidArgument( + "Inputs to operation ", this->name(), " of type ", + this->type_string(), + " must have the same size and shape. Input 0: ", + src1_shape.DebugString(), + " != input 1: ", src2_shape.DebugString())); + } if (!input1_in_mkl_format && src1_dims_size == 0) { - Tensor* dst_tensor = nullptr; - MklShape mkl_shape_dst; - mkl_shape_dst.SetMklTensor(false); - AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, - src1_tensor.shape(), mkl_shape_dst); - float user_i1 = (src1_tensor.scalar()()); - float user_i2 = (src2_tensor.scalar()()); - dst_tensor->scalar()() = - std::plus{}(user_i1, user_i2); - return; - } + Tensor* dst_tensor = nullptr; + MklShape mkl_shape_dst; + mkl_shape_dst.SetMklTensor(false); + AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, + src1_tensor.shape(), mkl_shape_dst); + float user_i1 = (src1_tensor.scalar()()); + float user_i2 = (src2_tensor.scalar()()); + dst_tensor->scalar()() = std::plus{}(user_i1, user_i2); + return; + } // If there is nothing to compute, return. if (!input1_in_mkl_format && !input2_in_mkl_format) { if (src1_tensor.shape().num_elements() == 0) { - Tensor* dst_tensor = nullptr; - MklShape mkl_shape_dst; - mkl_shape_dst.SetMklTensor(false); - AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, - src1_tensor.shape(), mkl_shape_dst); - return; + Tensor* dst_tensor = nullptr; + MklShape mkl_shape_dst; + mkl_shape_dst.SetMklTensor(false); + AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, + src1_tensor.shape(), mkl_shape_dst); + return; } } @@ -338,7 +357,7 @@ class MklAddNOp : public OpKernel { MklDnnData src2(&cpu_engine); MklDnnData dst(&cpu_engine); - int tmp_size = input1_in_mkl_format ? src2_dims_size: src1_dims_size; + int tmp_size = input1_in_mkl_format ? src2_dims_size : src1_dims_size; memory::dims dims(tmp_size); memory::dims strides(tmp_size); memory::desc md1({}, memory::data_undef, memory::format_undef); @@ -368,21 +387,19 @@ class MklAddNOp : public OpKernel { md1 = src1_mkl_shape.GetMklLayout(); memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat(); - auto src1_tf_data_format = MklDnnDataFormatToTFDataFormat( - src1_mkl_data_format); - auto src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), - src1_tf_data_format); - md2 = memory::desc(src2_dims, MklDnnType(), - src1_mkl_data_format); + auto src1_tf_data_format = + MklDnnDataFormatToTFDataFormat(src1_mkl_data_format); + auto src2_dims = + TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format); + md2 = memory::desc(src2_dims, MklDnnType(), src1_mkl_data_format); } else if (input2_in_mkl_format && !input1_in_mkl_format) { // Same comment as above. memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat(); - auto src2_tf_data_format = MklDnnDataFormatToTFDataFormat( - src2_mkl_data_format); - auto src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), - src2_tf_data_format); - md1 = memory::desc(src1_dims, MklDnnType(), - src2_mkl_data_format); + auto src2_tf_data_format = + MklDnnDataFormatToTFDataFormat(src2_mkl_data_format); + auto src1_dims = + TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format); + md1 = memory::desc(src1_dims, MklDnnType(), src2_mkl_data_format); md2 = src2_mkl_shape.GetMklLayout(); } else { @@ -456,20 +473,19 @@ class MklAddNOp : public OpKernel { output_mkl_shape.SetMklTensor(false); output_tf_shape = src1_tensor.shape(); } - AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, - output_tf_shape, output_mkl_shape); + AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, output_tf_shape, + output_mkl_shape); dst.SetUsrMemDataHandle(dst_tensor); // Create Sum op, and submit net for execution. net.push_back(sum(sum_pd, inputs, dst.GetOpMem())); stream(stream::kind::eager).submit(net).wait(); - } catch (mkldnn::error &e) { + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(ctx, errors::Aborted("Operation received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + ctx, errors::Aborted("Operation received an exception:", error_msg)); } } }; diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index 001834b13bdd64ffd0d536897fbc4a170c4c4117..4b5f7b831001458c222536be30bc40fcf5d2899a 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -396,7 +396,7 @@ class MklInputConversionOp : public OpKernel { auto cpu_engine = engine(engine::cpu, 0); MklDnnData tf_input(&cpu_engine); auto input_tf_md = mkl_output_mkl_shape.GetTfLayout(); - tf_input.SetUsrMem(input_tf_md, &tf_tensor); + tf_input.SetUsrMem(input_tf_md, tf_tensor); // Create reorder between tensorflow layout and Mkl layout. std::vector net; diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index 66bc7dd8eedbda08e052c8c3c1bd552c7b955ecb..95e0404ba8ab7d305e530239be30c7a842edf16d 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -43,7 +43,7 @@ limitations under the License. using mkldnn::lrn_forward; using mkldnn::lrn_backward; using mkldnn::prop_kind; -using mkldnn::algorithm::lrn_across_channels; +using mkldnn::lrn_across_channels; using mkldnn::stream; #endif diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index 896d56293303b06adb554cef7e2f3ef16a5a8eda..c46eabdde103913a712c3d058aa23a627d19f5ea 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -17,13 +17,13 @@ limitations under the License. #ifdef INTEL_MKL #ifdef INTEL_MKL_DNN +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/tensor_format.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "mkldnn.h" #include "mkldnn_types.h" @@ -31,16 +31,14 @@ limitations under the License. #include "tensorflow/core/util/mkl_util.h" #include "mkldnn.hpp" -using mkldnn::stream; using mkldnn::prop_kind; using mkldnn::softmax_forward; +using mkldnn::stream; namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; - - template class MklSoftmaxOp : public OpKernel { public: @@ -60,11 +58,11 @@ class MklSoftmaxOp : public OpKernel { MklDnnShape src_mkl_shape; GetMklShape(context, src_idx, &src_mkl_shape); - // src_dims is the dimenstion of src_tensor // dim of the dst will also be same as src_dims - auto src_tf_shape = src_mkl_shape.IsMklTensor() ? - src_mkl_shape.GetTfShape() : src_tensor.shape(); + auto src_tf_shape = src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetTfShape() + : src_tensor.shape(); auto src_dims = TFShapeToMklDnnDims(src_tf_shape); auto output_dims = src_dims; @@ -77,10 +75,10 @@ class MklSoftmaxOp : public OpKernel { // construct input Tf layout. For TF layout, although input shape // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's // layout - auto src_md = src_mkl_shape.IsMklTensor() - ? src_mkl_shape.GetMklLayout() - : memory::desc(src_dims, MklDnnType(), - memory::format::nc); + auto src_md = + src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), memory::format::nc); // src: setting memory descriptor and op memory descriptor // Basically following two functions maps the TF "src_tensor" to mkl @@ -95,8 +93,8 @@ class MklSoftmaxOp : public OpKernel { int axis = 1; // axis to which softmax will be applied auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring, src.GetOpMemDesc(), axis); - auto softmax_fwd_pd = softmax_forward::primitive_desc(softmax_fwd_desc, - cpu_engine); + auto softmax_fwd_pd = + softmax_forward::primitive_desc(softmax_fwd_desc, cpu_engine); // add: output Tensor* output_tensor = nullptr; @@ -136,9 +134,9 @@ class MklSoftmaxOp : public OpKernel { net.push_back(softmax_fwd); stream(stream::kind::eager).submit(net).wait(); } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + ", message: " + - string(e.message) + ", in file " + string(__FILE__) + - ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( context, errors::Aborted("Operation received an exception:", error_msg)); @@ -148,7 +146,7 @@ class MklSoftmaxOp : public OpKernel { /* Register DNN kernels for supported operations and supported types - right now * it is only Softmax and f32 */ -#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type) \ +#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type) \ REGISTER_KERNEL_BUILDER(Name("_MklSoftmax") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ @@ -156,7 +154,6 @@ class MklSoftmaxOp : public OpKernel { MklSoftmaxOp); TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES); - } // namespace tensorflow #endif // INTEL_MKL_DNN diff --git a/tensorflow/core/kernels/neon/BUILD b/tensorflow/core/kernels/neon/BUILD index 536b2bdc03c5dc91e8e3e25dd9fbba82cd29fd5b..c3d24e50effb3fe5184e264064393a7f339105f0 100644 --- a/tensorflow/core/kernels/neon/BUILD +++ b/tensorflow/core/kernels/neon/BUILD @@ -39,6 +39,6 @@ tf_kernel_library( "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:ops_util", - "@gemmlowp//:gemmlowp", + "@gemmlowp", ], ) diff --git a/tensorflow/core/kernels/neon/depthwiseconv_float.h b/tensorflow/core/kernels/neon/depthwiseconv_float.h index acd58a644f3b0b0b578778f8c017efff30771efa..11f5be7c03dcd3c03014a40b4901ef9fef1b892b 100644 --- a/tensorflow/core/kernels/neon/depthwiseconv_float.h +++ b/tensorflow/core/kernels/neon/depthwiseconv_float.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ +#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ +#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ #include "public/gemmlowp.h" #include "tensorflow/core/kernels/neon/types.h" @@ -722,4 +722,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, } // end namespace neon } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ +#endif // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ diff --git a/tensorflow/core/kernels/neon/types.h b/tensorflow/core/kernels/neon/types.h index 4ece22f015954a1867dd2a4a5365cc93c1eaee5d..05ff1bcc6cdbe7bf26766fc0b11909e3da8de71f 100644 --- a/tensorflow/core/kernels/neon/types.h +++ b/tensorflow/core/kernels/neon/types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ +#ifndef TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ +#define TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ #include "tensorflow/core/platform/logging.h" @@ -70,4 +70,4 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { } // end namespace neon } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ +#endif // TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index 2923c38662e3c2b74df5c72c513b5e3ecab9f5e5..2033fbf5dc3f238b665c6f4afced06e90c81bb7c 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -139,7 +139,6 @@ class PackOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_PACK); TF_CALL_QUANTIZED_TYPES(REGISTER_PACK); -TF_CALL_bfloat16(REGISTER_PACK); TF_CALL_variant(REGISTER_PACK); #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 6a52a15c931290fcdaabbb259f9dbd86f1824a30..d4241b58090e4e4c1300fdcdc0e46411aa5a88f3 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -222,7 +222,7 @@ void DnnPoolingOp::Compute( output_desc, &output_data) .ok(); OP_REQUIRES(context, status, - errors::Internal("cudnn PoolBackward launch failed")); + errors::Internal("cudnn PoolForward launch failed")); if (data_format == FORMAT_NHWC) { /// Transform the output data from NCHW back to NHWC diff --git a/tensorflow/core/kernels/population_count_op.h b/tensorflow/core/kernels/population_count_op.h index de89582e139b03de48719749ef29a0d3bb638e0e..2c9812967366d8b943715f08caf07ce5804877ca 100644 --- a/tensorflow/core/kernels/population_count_op.h +++ b/tensorflow/core/kernels/population_count_op.h @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ +#ifndef TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" @@ -35,4 +35,4 @@ struct PopulationCount { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ +#endif // TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h index 7c18496357c468322313b7b9064cfd7b3a22661a..9fafe6bb65406a1dbcbcb63624fe58019f9e83a3 100644 --- a/tensorflow/core/kernels/quantization_utils.h +++ b/tensorflow/core/kernels/quantization_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ #define EIGEN_USE_THREADS @@ -956,4 +956,4 @@ class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ +#endif // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ diff --git a/tensorflow/core/kernels/reference_gemm.h b/tensorflow/core/kernels/reference_gemm.h index bb2a21720f337c61c38e91688fa99360d1270652..c9cc04ed1b7387b9a4a2f335a14100d7c691d507 100644 --- a/tensorflow/core/kernels/reference_gemm.h +++ b/tensorflow/core/kernels/reference_gemm.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ +#ifndef TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ +#define TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ #include @@ -92,4 +92,4 @@ void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c, } } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ +#endif // TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h index 3fa052108ec2d466caead1cb3c14e2ecc00a45f9..7de45eaaa16030b6e80c427b2db8ebd7280aed00 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -86,4 +86,4 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ +#endif // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h index 541c26baaf999d6ad7b34aaf65bf43cb788da582..f0471442781de7c901e1c1cec69b840186015ce3 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ #include #include @@ -312,4 +312,4 @@ class RemoteFusedGraphExecuteUtils { }; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ +#endif // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h index ed583afd13824eff789ea556045507fb4cff44e6..6777748b63b299b450e4fdc09376f18127c8ab85 100644 --- a/tensorflow/core/kernels/reshape_util.h +++ b/tensorflow/core/kernels/reshape_util.h @@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ +#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -28,4 +28,4 @@ void Reshape(OpKernelContext *context, const Tensor &input_indices_in, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ +#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h index cffc326174b274e5e42ee5676a6addad7d7c9203..c6c9d4e6588f1f4d847810de1e736220d5572f25 100644 --- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ +#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ // Functor definitions for ScatterND ops, must be compilable by nvcc. @@ -257,4 +257,4 @@ REGISTER_SCATTER_ND_MATH_SYCL(int32); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ +#endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index b10bea72ba89e7089e0668389995c629644b534d..bcdd42c80c18af381988808db74319e5072f38a7 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" @@ -98,4 +98,4 @@ struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor #include @@ -109,4 +109,4 @@ class Spectrogram { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ +#endif // TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ diff --git a/tensorflow/core/kernels/spectrogram_test_utils.cc b/tensorflow/core/kernels/spectrogram_test_utils.cc index 046f6344dfed44069cf27f1b6d923db10498c98c..872a6e9d1bcce09765d1531c5f2898b2badc66a7 100644 --- a/tensorflow/core/kernels/spectrogram_test_utils.cc +++ b/tensorflow/core/kernels/spectrogram_test_utils.cc @@ -70,10 +70,24 @@ bool ReadRawFloatFileToComplexVector( int offset = 0; const int end = data_string.size(); while (offset < end) { +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + char arr[4]; + for (int i = 0; i < kBytesPerValue; ++i) { + arr[3 - i] = *(data_string.data() + offset + i); + } + memcpy(&real_out, arr, kBytesPerValue); + offset += kBytesPerValue; + for (int i = 0; i < kBytesPerValue; ++i) { + arr[3 - i] = *(data_string.data() + offset + i); + } + memcpy(&imag_out, arr, kBytesPerValue); + offset += kBytesPerValue; +#else memcpy(&real_out, data_string.data() + offset, kBytesPerValue); offset += kBytesPerValue; memcpy(&imag_out, data_string.data() + offset, kBytesPerValue); offset += kBytesPerValue; +#endif if (row_counter >= row_length) { data->push_back(data_row); data_row.clear(); diff --git a/tensorflow/core/kernels/spectrogram_test_utils.h b/tensorflow/core/kernels/spectrogram_test_utils.h index 59a903549e853b0d270ba8cd565830f1310b677e..d4187076e748af5454e6dd03d05e49d923f1e9d2 100644 --- a/tensorflow/core/kernels/spectrogram_test_utils.h +++ b/tensorflow/core/kernels/spectrogram_test_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ +#ifndef TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ #include #include @@ -78,4 +78,4 @@ void SineWave(int sample_rate, float frequency, float duration_seconds, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ +#endif // TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h index a6eed4935d5c4a2aaa8618bab88998d4ce060ecb..054b31ef9e0b4904d8803d1c4542ff805e0a7673 100644 --- a/tensorflow/core/kernels/tile_ops_cpu_impl.h +++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ +#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ #define EIGEN_USE_THREADS @@ -68,4 +68,4 @@ TF_CALL_int64(DEFINE_TYPE); } // end namespace functor } // end namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ +#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ diff --git a/tensorflow/core/kernels/tile_ops_gpu_impl.h b/tensorflow/core/kernels/tile_ops_gpu_impl.h index 592f99e9b7b5c928c7e522b734186ab0225cd1d0..8da337dabd2e7fc021ec92df97091d15fa39aeab 100644 --- a/tensorflow/core/kernels/tile_ops_gpu_impl.h +++ b/tensorflow/core/kernels/tile_ops_gpu_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ +#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ // Header used to split up compilation of GPU tile ops. For each type you want // to have tile ops, create a .cu.cc file containing @@ -56,4 +56,4 @@ limitations under the License. } \ } -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ +#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc index 41b73fdaf4aced13070164afb81825592637f8c4..5198df7e16e020f0ee19baa387ccae899e21499a 100644 --- a/tensorflow/core/kernels/transpose_functor_cpu.cc +++ b/tensorflow/core/kernels/transpose_functor_cpu.cc @@ -88,6 +88,18 @@ struct Transpose { internal::TransposeUsingEigen(d, in, perm, conjugate, out); break; + case 6: + internal::TransposeUsingEigen(d, in, perm, conjugate, + out); + break; + case 7: + internal::TransposeUsingEigen(d, in, perm, conjugate, + out); + break; + case 8: + internal::TransposeUsingEigen(d, in, perm, conjugate, + out); + break; default: TransposeSimple(d, in, perm, out); break; diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc index 493dac9a7ca5a57dba10a3c155299d78e3a69f38..d6a237d6c183cbacf2b5bbbd5f5e9034e84c73af 100644 --- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc @@ -201,6 +201,27 @@ struct Transpose { out); } break; + case 6: + if (!internal::TransposeUsingTile::run(d, in, perm, + out)) { + internal::TransposeUsingEigen(d, in, perm, conjugate, + out); + } + break; + case 7: + if (!internal::TransposeUsingTile::run(d, in, perm, + out)) { + internal::TransposeUsingEigen(d, in, perm, conjugate, + out); + } + break; + case 8: + if (!internal::TransposeUsingTile::run(d, in, perm, + out)) { + internal::TransposeUsingEigen(d, in, perm, conjugate, + out); + } + break; default: internal::TransposeSimple(d, in, perm, out); break; diff --git a/tensorflow/core/kernels/winograd_transform.h b/tensorflow/core/kernels/winograd_transform.h index 5caee9fdc14ddeeae5adbf9fa22cfc04ac53b58a..d22710e503285150cf270f5c4e32796f275171a0 100644 --- a/tensorflow/core/kernels/winograd_transform.h +++ b/tensorflow/core/kernels/winograd_transform.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ +#ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ +#define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ #include "tensorflow/core/kernels/deep_conv2d.h" @@ -374,4 +374,4 @@ void WinogradTransform::GetOutputTransformMatrix(const int64 rows, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ +#endif // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index dc21cee3a8a5a76d8fe5d0d88eae03e7cede3f58..0f8d027caadab2dee04d3041ed515a40f22476f3 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -67,10 +67,12 @@ class SoftmaxXentWithLogitsOp : public OpKernel { // Try to reuse the logits_in buffer for the backprop output. OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 1, logits_in.shape(), &back_out)); - functor::XentFunctor functor; - functor(context->eigen_device(), logits_in.matrix(), - labels_in.matrix(), scratch.matrix(), loss_out->vec(), - back_out->matrix()); + if (logits_in.dim_size(0) > 0) { + functor::XentFunctor functor; + functor(context->eigen_device(), logits_in.matrix(), + labels_in.matrix(), scratch.matrix(), loss_out->vec(), + back_out->matrix()); + } } }; diff --git a/tensorflow/core/kernels/xsmm_conv2d.h b/tensorflow/core/kernels/xsmm_conv2d.h index b439511dc78b46dc90eb8523b98b42d9ba1de45a..003291329a8d3c4062aee00c5b5e1ab8e0ebf8c2 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.h +++ b/tensorflow/core/kernels/xsmm_conv2d.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_ +#ifndef TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_ +#define TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_ #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/tensor_format.h" @@ -57,4 +57,4 @@ struct XsmmBkwFilterConv2D { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_ +#endif // TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_ diff --git a/tensorflow/core/lib/core/bitmap.h b/tensorflow/core/lib/core/bitmap.h index b30479fa1bbec58697d50a6bb85d6f430454e5e9..8ff1e666b4ffcdc09353b57b949584404be4aeed 100644 --- a/tensorflow/core/lib/core/bitmap.h +++ b/tensorflow/core/lib/core/bitmap.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ +#ifndef TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ +#define TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -103,4 +103,4 @@ inline void Bitmap::clear(size_t i) { } // namespace core } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ +#endif // TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index b5c0d9f621dd2e6fa8c5fd64d71f886fcfb3fd1e..0f6999c88fca3fd7ab91d2f3e28348e22d106f45 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/lib/gif/gif_io.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/gif.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -44,7 +45,8 @@ int input_callback(GifFileType* gif_file, GifByteType* buf, int size) { } uint8* Decode(const void* srcdata, int datasize, - std::function allocate_output) { + const std::function& allocate_output, + string* error_string) { int error_code = D_GIF_SUCCEEDED; InputBufferInfo info = {reinterpret_cast(srcdata), datasize}; GifFileType* gif_file = @@ -57,17 +59,17 @@ uint8* Decode(const void* srcdata, int datasize, } }); if (error_code != D_GIF_SUCCEEDED) { - LOG(ERROR) << "Fail to open gif file, reason: " - << GifErrorString(error_code); + *error_string = strings::StrCat("failed to open gif file: ", + GifErrorString(error_code)); return nullptr; } if (DGifSlurp(gif_file) != GIF_OK) { - LOG(ERROR) << "Fail to slurp gif file, reason: " - << GifErrorString(gif_file->Error); + *error_string = strings::StrCat("failed to slurp gif file: ", + GifErrorString(gif_file->Error)); return nullptr; } if (gif_file->ImageCount <= 0) { - LOG(ERROR) << "Gif file does not contain any image"; + *error_string = strings::StrCat("gif file does not contain any image"); return nullptr; } @@ -83,7 +85,7 @@ uint8* Decode(const void* srcdata, int datasize, GifImageDesc* img_desc = &this_image->ImageDesc; if (img_desc->Left != 0 || img_desc->Top != 0 || img_desc->Width != width || img_desc->Height != height) { - LOG(ERROR) << "Can't process optimized gif."; + *error_string = strings::StrCat("can't process optimized gif"); return nullptr; } diff --git a/tensorflow/core/lib/gif/gif_io.h b/tensorflow/core/lib/gif/gif_io.h index 5399e6a53812b70ac25d33dc5c8acd93a8a82f04..0a7967a5a1534ea61e6adab67492802882a02c5c 100644 --- a/tensorflow/core/lib/gif/gif_io.h +++ b/tensorflow/core/lib/gif/gif_io.h @@ -43,7 +43,8 @@ namespace tensorflow { namespace gif { uint8* Decode(const void* srcdata, int datasize, - std::function allocate_output); + const std::function& allocate_output, + string* error_string); } // namespace gif } // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/compactptrset.h b/tensorflow/core/lib/gtl/compactptrset.h index 1d4d6cc8d2de035f345c4fef8121041b091c24d7..d3d23b94aa26471f7b0d178296c7112c5084f8cf 100644 --- a/tensorflow/core/lib/gtl/compactptrset.h +++ b/tensorflow/core/lib/gtl/compactptrset.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ +#ifndef TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ +#define TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ #include #include "tensorflow/core/lib/gtl/flatset.h" @@ -205,4 +205,4 @@ class CompactPointerSet { } // namespace gtl } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ +#endif // TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ diff --git a/tensorflow/core/lib/gtl/flatmap.h b/tensorflow/core/lib/gtl/flatmap.h index 6dd67ad2ea56a2691d27cfe4b9a11c0aafa05d01..889d2ddaa6be36332a3b810c0aefef6ecb684e40 100644 --- a/tensorflow/core/lib/gtl/flatmap.h +++ b/tensorflow/core/lib/gtl/flatmap.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ +#ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ +#define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ #include #include @@ -379,4 +379,4 @@ class FlatMap { } // namespace gtl } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ +#endif // TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ diff --git a/tensorflow/core/lib/gtl/flatrep.h b/tensorflow/core/lib/gtl/flatrep.h index bb405b327aa86983a171727b76a63109d7028431..0d7e7487fc33353603bf3c4d56d8d04466e326a1 100644 --- a/tensorflow/core/lib/gtl/flatrep.h +++ b/tensorflow/core/lib/gtl/flatrep.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ +#ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ +#define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ #include #include @@ -328,4 +328,4 @@ class FlatRep { } // namespace gtl } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ +#endif // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h index 2b7f31ab224f3d70fe5e69ced17f54cc1e742453..f31e3abe4115887ed1f2ed3bec52c73b2622715c 100644 --- a/tensorflow/core/lib/gtl/flatset.h +++ b/tensorflow/core/lib/gtl/flatset.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ +#ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ +#define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ #include #include @@ -278,4 +278,4 @@ class FlatSet { } // namespace gtl } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ +#endif // TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h index 2b824f35f80de47f951477a9352bedeca1290848..924619f40f23152e8155651c72538ef5da98e611 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.h +++ b/tensorflow/core/lib/io/buffered_inputstream.h @@ -104,4 +104,4 @@ class BufferedInputStream : public InputStreamInterface { } // namespace io } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#endif // TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_ diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h index 7a0c5c12a7461546a7511ccc967237336a61b744..ef90c60a3a411cdc94a9f92522116db340e04f1b 100644 --- a/tensorflow/core/lib/io/compression.h +++ b/tensorflow/core/lib/io/compression.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ +#ifndef TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ +#define TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ namespace tensorflow { namespace io { @@ -27,4 +27,4 @@ extern const char kGzip[]; } } -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ +#endif // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ diff --git a/tensorflow/core/lib/io/inputstream_interface.h b/tensorflow/core/lib/io/inputstream_interface.h index 096248693bb83cb4e4ede64fb3e9aac2bee42c7a..3083d20776f8a85d03a07756954980fd7e100141 100644 --- a/tensorflow/core/lib/io/inputstream_interface.h +++ b/tensorflow/core/lib/io/inputstream_interface.h @@ -54,4 +54,4 @@ class InputStreamInterface { } // namespace io } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#endif // TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ diff --git a/tensorflow/core/lib/io/random_inputstream.cc b/tensorflow/core/lib/io/random_inputstream.cc index 8b8c1392a1dce339a56b718af036248f22ba0b59..09336e79cda67b324299d78c65217e6a7b40dc21 100644 --- a/tensorflow/core/lib/io/random_inputstream.cc +++ b/tensorflow/core/lib/io/random_inputstream.cc @@ -57,6 +57,43 @@ Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read, return Status::OK(); } +// To limit memory usage, the default implementation of SkipNBytes() only reads +// 8MB at a time. +static constexpr int64 kMaxSkipSize = 8 * 1024 * 1024; + +Status RandomAccessInputStream::SkipNBytes(int64 bytes_to_skip) { + if (bytes_to_skip < 0) { + return errors::InvalidArgument("Can't skip a negative number of bytes"); + } + std::unique_ptr scratch(new char[kMaxSkipSize]); + // Try to read 1 bytes first, if we could complete the read then EOF is + // not reached yet and we could return. + if (bytes_to_skip > 0) { + StringPiece data; + Status s = file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get()); + if ((s.ok() || errors::IsOutOfRange(s)) && data.size() == 1) { + pos_ += bytes_to_skip; + return Status::OK(); + } + } + // Read kDefaultSkipSize at a time till bytes_to_skip. + while (bytes_to_skip > 0) { + int64 bytes_to_read = std::min(kMaxSkipSize, bytes_to_skip); + StringPiece data; + Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get()); + if (s.ok() || errors::IsOutOfRange(s)) { + pos_ += data.size(); + } else { + return s; + } + if (data.size() < bytes_to_read) { + return errors::OutOfRange("reached end of file"); + } + bytes_to_skip -= bytes_to_read; + } + return Status::OK(); +} + int64 RandomAccessInputStream::Tell() const { return pos_; } } // namespace io diff --git a/tensorflow/core/lib/io/random_inputstream.h b/tensorflow/core/lib/io/random_inputstream.h index 09ebe9ba49e741945457c82cf0c64b3c1268a694..bdbdbd71ff914cfaf1690b2813ddbab070a9f99a 100644 --- a/tensorflow/core/lib/io/random_inputstream.h +++ b/tensorflow/core/lib/io/random_inputstream.h @@ -34,6 +34,8 @@ class RandomAccessInputStream : public InputStreamInterface { Status ReadNBytes(int64 bytes_to_read, string* result) override; + Status SkipNBytes(int64 bytes_to_skip) override; + int64 Tell() const override; Status Seek(int64 position) { diff --git a/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h b/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h index 5d330a2c5a3d97456495893d3bb87c376beeeb1f..5aea503846df7c1b0f3c3f140a820dc0cd951726 100644 --- a/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h +++ b/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_ +#ifndef TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_ +#define TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_ #include #include "tensorflow/core/lib/core/status.h" @@ -117,4 +117,4 @@ class SnappyOutputBuffer { } // namespace io } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_ +#endif // TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_ diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index 5cad2e945705701662d845315c86acbf70f1f1d3..3d86d89a99204c1c8a80081b299e28837141b33d 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_ +#ifndef TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_ +#define TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_ #include @@ -143,4 +143,4 @@ class ZlibOutputBuffer : public WritableFile { } // namespace io } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_ +#endif // TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_ diff --git a/tensorflow/core/lib/monitoring/collected_metrics.h b/tensorflow/core/lib/monitoring/collected_metrics.h index acdb0d86edb1a15631c324afe9d535e0660c4b98..e2009816097804c228a094575d05e732c08b4b90 100644 --- a/tensorflow/core/lib/monitoring/collected_metrics.h +++ b/tensorflow/core/lib/monitoring/collected_metrics.h @@ -17,8 +17,8 @@ limitations under the License. // These are to be used only by the CollectionRegistry and exporters which // collect metrics using the CollectionRegistry. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ #include #include @@ -151,4 +151,4 @@ struct CollectedMetrics { } // namespace monitoring } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index 2c8e250c5631ee8a56d6871c1a61ef17efc97c82..63cc0f550df79c4c6821f4618a4d8324969577b2 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ #include #include @@ -356,4 +356,4 @@ MetricCollector MetricCollectorGetter::Get( } // namespace monitoring } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ diff --git a/tensorflow/core/lib/monitoring/counter.h b/tensorflow/core/lib/monitoring/counter.h index 7240348a9b764e3092f71da4bce9a953c08e7900..8ff810db41d98024eb9e6be1e1c2a10a4b792a75 100644 --- a/tensorflow/core/lib/monitoring/counter.h +++ b/tensorflow/core/lib/monitoring/counter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ // We replace this implementation with a null implementation for mobile // platforms. @@ -172,4 +172,4 @@ CounterCell* Counter::GetCell(const Labels&... labels) } // namespace tensorflow #endif // IS_MOBILE_PLATFORM -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ diff --git a/tensorflow/core/lib/monitoring/gauge.h b/tensorflow/core/lib/monitoring/gauge.h index ec978a91935890cb0563f39ba0e6554a03d7c86e..ee9a862f40a8266b1f3fa35150a7209f1b61819b 100644 --- a/tensorflow/core/lib/monitoring/gauge.h +++ b/tensorflow/core/lib/monitoring/gauge.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ // We replace this implementation with a null implementation for mobile // platforms. @@ -241,4 +241,4 @@ GaugeCell* Gauge::GetCell( } // namespace tensorflow #endif // IS_MOBILE_PLATFORM -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index f046842618a03f7a161a11d3b493b71be50ad988..5ecadcc4272581a5e4e2c934cd605bd1a1110fcd 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ #include #include @@ -139,4 +139,4 @@ class MetricDef : public AbstractMetricDef { } // namespace monitoring } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ diff --git a/tensorflow/core/lib/monitoring/mobile_counter.h b/tensorflow/core/lib/monitoring/mobile_counter.h index c30bfe026f15922213312c68af0236e3d07d9380..c297d843d2fa7fb487b315fc8870e62fd5ec930d 100644 --- a/tensorflow/core/lib/monitoring/mobile_counter.h +++ b/tensorflow/core/lib/monitoring/mobile_counter.h @@ -15,8 +15,8 @@ limitations under the License. // Null implementation of the Counter metric for mobile platforms. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -64,4 +64,4 @@ class Counter { } // namespace monitoring } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ diff --git a/tensorflow/core/lib/monitoring/mobile_gauge.h b/tensorflow/core/lib/monitoring/mobile_gauge.h index ac13ad35c020a45770e8acd7cd0820cbc2ac8cf4..a03b41aef334901eec206ce2ebfcf28251f4e28e 100644 --- a/tensorflow/core/lib/monitoring/mobile_gauge.h +++ b/tensorflow/core/lib/monitoring/mobile_gauge.h @@ -15,8 +15,8 @@ limitations under the License. // Null implementation of the Gauge metric for mobile platforms. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -69,4 +69,4 @@ class Gauge { } // namespace monitoring } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ diff --git a/tensorflow/core/lib/monitoring/mobile_sampler.h b/tensorflow/core/lib/monitoring/mobile_sampler.h index cf390e5c7f67723e017b991cd7d0cd15266e24d9..77310dd619fd886c65b3ae3bf7c12d050d82c9d8 100644 --- a/tensorflow/core/lib/monitoring/mobile_sampler.h +++ b/tensorflow/core/lib/monitoring/mobile_sampler.h @@ -15,8 +15,8 @@ limitations under the License. // Null implementation of the Sampler metric for mobile platforms. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_ #include @@ -98,4 +98,4 @@ class Sampler { } // namespace monitoring } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_ diff --git a/tensorflow/core/lib/monitoring/sampler.h b/tensorflow/core/lib/monitoring/sampler.h index c7a05428e2dced68ce3dc165616837084916f49d..a4f397f5566a7425b197e5de91aed811ec08e564 100644 --- a/tensorflow/core/lib/monitoring/sampler.h +++ b/tensorflow/core/lib/monitoring/sampler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ // We replace this implementation with a null implementation for mobile // platforms. @@ -215,4 +215,4 @@ SamplerCell* Sampler::GetCell(const Labels&... labels) } // namespace tensorflow #endif // IS_MOBILE_PLATFORM -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ +#endif // TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h index ed6d0af0105c37e77debdab1db549d131752d615..05dbda6e152b7a3b820e36f7c1b56094e2dc04fa 100644 --- a/tensorflow/core/lib/strings/proto_text_util.h +++ b/tensorflow/core/lib/strings/proto_text_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ +#ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" @@ -164,4 +164,4 @@ bool ProtoParseStringLiteralFromScanner(Scanner* scanner, string* value); } // namespace strings } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ +#endif // TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a64582acee7e84a6eb5c73a61d57148d994558c9 --- /dev/null +++ b/tensorflow/core/ops/batch_ops.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("Batch") + .Input("in_tensors: T") + .Output("batched_tensors: T") + .Output("batch_index: int64") + .Output("id: int64") + .Attr("num_batch_threads: int") + .Attr("max_batch_size: int") + .Attr("batch_timeout_micros: int") + .Attr("allowed_batch_sizes: list(int) = []") + .Attr("grad_timeout_micros: int") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("batching_queue: string = ''") + .Attr("T: list(type)") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector in_shapes; + TF_RETURN_IF_ERROR(c->input("in_tensors", &in_shapes)); + std::vector out_shapes(in_shapes.size()); + for (int i = 0; i < in_shapes.size(); ++i) { + TF_RETURN_IF_ERROR( + c->ReplaceDim(in_shapes[i], 0, c->UnknownDim(), &out_shapes[i])); + } + TF_RETURN_IF_ERROR(c->set_output("batched_tensors", out_shapes)); + TF_RETURN_IF_ERROR(c->set_output("id", {c->Scalar()})); + TF_RETURN_IF_ERROR(c->set_output( + "batch_index", + {c->MakeShape({shape_inference::DimensionOrConstant(c->UnknownDim()), + shape_inference::DimensionOrConstant(3)})})); + return Status::OK(); + }); + +REGISTER_OP("Unbatch") + .Input("batched_tensor: T") + .Input("batch_index: int64") + .Input("id: int64") + .Output("unbatched_tensor: T") + .Attr("timeout_micros: int") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle out_shape; + TF_RETURN_IF_ERROR( + c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &out_shape)); + c->set_output(0, out_shape); + return Status::OK(); + }); + +REGISTER_OP("UnbatchGrad") + .Input("original_input: T") + .Input("batch_index: int64") + .Input("grad: T") + .Input("id: int64") + .Output("batched_grad: T") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(2)))); + return Status::OK(); + }); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 08b685319eaf725b01ab460903b82680c6bd247f..65ab81931ad4261f432034f73269d1e8c8005384 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -7903,6 +7903,76 @@ op { } } } +op { + name: "Batch" + input_arg { + name: "in_tensors" + type_list_attr: "T" + } + output_arg { + name: "batched_tensors" + type_list_attr: "T" + } + output_arg { + name: "batch_index" + type: DT_INT64 + } + output_arg { + name: "id" + type: DT_INT64 + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "grad_timeout_micros" + type: "int" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + minimum: 1 + } +} op { name: "BatchCholesky" input_arg { @@ -18752,6 +18822,57 @@ op { } is_stateful: true } +op { + name: "FixedLengthRecordReader" + output_arg { + name: "reader_handle" + type: DT_STRING + is_ref: true + } + attr { + name: "header_bytes" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "record_bytes" + type: "int" + } + attr { + name: "footer_bytes" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "hop_bytes" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + deprecation { + version: 26 + } + is_stateful: true +} op { name: "FixedLengthRecordReaderV2" output_arg { @@ -21314,6 +21435,32 @@ op { } is_stateful: true } +op { + name: "IdentityReader" + output_arg { + name: "reader_handle" + type: DT_STRING + is_ref: true + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + deprecation { + version: 26 + } + is_stateful: true +} op { name: "IdentityReaderV2" output_arg { @@ -22532,6 +22679,30 @@ op { } is_stateful: true } +op { + name: "IteratorGetNextSync" + input_arg { + name: "iterator" + type: DT_RESOURCE + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "IteratorSetStatsAggregator" input_arg { @@ -38474,6 +38645,46 @@ op { } } } +op { + name: "ResizeBilinear" + input_arg { + name: "images" + type_attr: "T" + } + input_arg { + name: "size" + type: DT_INT32 + } + output_arg { + name: "resized_images" + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_UINT8 + type: DT_INT16 + type: DT_UINT16 + type: DT_INT32 + type: DT_INT64 + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "align_corners" + type: "bool" + default_value { + b: false + } + } +} op { name: "ResizeBilinearGrad" input_arg { @@ -38507,6 +38718,40 @@ op { } } } +op { + name: "ResizeBilinearGrad" + input_arg { + name: "grads" + type: DT_FLOAT + } + input_arg { + name: "original_image" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_BFLOAT16 + type: DT_HALF + type: DT_DOUBLE + } + } + } + attr { + name: "align_corners" + type: "bool" + default_value { + b: false + } + } +} op { name: "ResizeNearestNeighbor" input_arg { @@ -61788,6 +62033,39 @@ op { } is_stateful: true } +op { + name: "TFRecordReader" + output_arg { + name: "reader_handle" + type: DT_STRING + is_ref: true + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "compression_type" + type: "string" + default_value { + s: "" + } + } + deprecation { + version: 26 + } + is_stateful: true +} op { name: "TFRecordReaderV2" output_arg { @@ -62189,6 +62467,16 @@ op { type: DT_STRING } } +op { + name: "TensorArrayCloseV2" + input_arg { + name: "handle" + type: DT_STRING + } + deprecation { + version: 26 + } +} op { name: "TensorArrayCloseV3" input_arg { @@ -62366,6 +62654,41 @@ op { } } } +op { + name: "TensorArrayGatherV2" + input_arg { + name: "handle" + type: DT_STRING + } + input_arg { + name: "indices" + type: DT_INT32 + } + input_arg { + name: "flow_in" + type: DT_FLOAT + } + output_arg { + name: "value" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "element_shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + deprecation { + version: 26 + } +} op { name: "TensorArrayGatherV3" input_arg { @@ -62443,6 +62766,29 @@ op { } is_stateful: true } +op { + name: "TensorArrayGradV2" + input_arg { + name: "handle" + type: DT_STRING + } + input_arg { + name: "flow_in" + type: DT_FLOAT + } + output_arg { + name: "grad_handle" + type: DT_STRING + } + attr { + name: "source" + type: "string" + } + deprecation { + version: 26 + } + is_stateful: true +} op { name: "TensorArrayGradV3" input_arg { @@ -62549,6 +62895,32 @@ op { type: "type" } } +op { + name: "TensorArrayReadV2" + input_arg { + name: "handle" + type: DT_STRING + } + input_arg { + name: "index" + type: DT_INT32 + } + input_arg { + name: "flow_in" + type: DT_FLOAT + } + output_arg { + name: "value" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + deprecation { + version: 26 + } +} op { name: "TensorArrayReadV3" input_arg { @@ -62631,6 +63003,36 @@ op { type: "type" } } +op { + name: "TensorArrayScatterV2" + input_arg { + name: "handle" + type: DT_STRING + } + input_arg { + name: "indices" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + input_arg { + name: "flow_in" + type: DT_FLOAT + } + output_arg { + name: "flow_out" + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + } + deprecation { + version: 26 + } +} op { name: "TensorArrayScatterV3" input_arg { @@ -62693,6 +63095,24 @@ op { type: DT_INT32 } } +op { + name: "TensorArraySizeV2" + input_arg { + name: "handle" + type: DT_STRING + } + input_arg { + name: "flow_in" + type: DT_FLOAT + } + output_arg { + name: "size" + type: DT_INT32 + } + deprecation { + version: 26 + } +} op { name: "TensorArraySizeV3" input_arg { @@ -62767,6 +63187,36 @@ op { type: "type" } } +op { + name: "TensorArraySplitV2" + input_arg { + name: "handle" + type: DT_STRING + } + input_arg { + name: "value" + type_attr: "T" + } + input_arg { + name: "lengths" + type: DT_INT64 + } + input_arg { + name: "flow_in" + type: DT_FLOAT + } + output_arg { + name: "flow_out" + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + } + deprecation { + version: 26 + } +} op { name: "TensorArraySplitV3" input_arg { @@ -62868,6 +63318,55 @@ op { } is_stateful: true } +op { + name: "TensorArrayV2" + input_arg { + name: "size" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_STRING + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "element_shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + attr { + name: "dynamic_size" + type: "bool" + default_value { + b: false + } + } + attr { + name: "clear_after_read" + type: "bool" + default_value { + b: true + } + } + attr { + name: "tensor_array_name" + type: "string" + default_value { + s: "" + } + } + deprecation { + version: 26 + } + is_stateful: true +} op { name: "TensorArrayV3" input_arg { @@ -63033,6 +63532,36 @@ op { type: "type" } } +op { + name: "TensorArrayWriteV2" + input_arg { + name: "handle" + type: DT_STRING + } + input_arg { + name: "index" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + input_arg { + name: "flow_in" + type: DT_FLOAT + } + output_arg { + name: "flow_out" + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + } + deprecation { + version: 26 + } +} op { name: "TensorArrayWriteV3" input_arg { @@ -63085,6 +63614,27 @@ op { } is_stateful: true } +op { + name: "TensorListElementShape" + input_arg { + name: "input_handle" + type: DT_VARIANT + } + output_arg { + name: "element_shape" + type_attr: "shape_type" + } + attr { + name: "shape_type" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "TensorListFromTensor" input_arg { @@ -63114,6 +63664,25 @@ op { } } } +op { + name: "TensorListGetItem" + input_arg { + name: "input_handle" + type: DT_VARIANT + } + input_arg { + name: "index" + type: DT_INT32 + } + output_arg { + name: "item" + type_attr: "element_dtype" + } + attr { + name: "element_dtype" + type: "type" + } +} op { name: "TensorListLength" input_arg { @@ -63163,6 +63732,58 @@ op { type: "type" } } +op { + name: "TensorListReserve" + input_arg { + name: "element_shape" + type_attr: "shape_type" + } + input_arg { + name: "num_elements" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "element_dtype" + type: "type" + } + attr { + name: "shape_type" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { + name: "TensorListSetItem" + input_arg { + name: "input_handle" + type: DT_VARIANT + } + input_arg { + name: "index" + type: DT_INT32 + } + input_arg { + name: "item" + type_attr: "element_dtype" + } + output_arg { + name: "output_handle" + type: DT_VARIANT + } + attr { + name: "element_dtype" + type: "type" + } +} op { name: "TensorListStack" input_arg { @@ -63319,6 +63940,39 @@ op { } is_stateful: true } +op { + name: "TextLineReader" + output_arg { + name: "reader_handle" + type: DT_STRING + is_ref: true + } + attr { + name: "skip_header_lines" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + deprecation { + version: 26 + } + is_stateful: true +} op { name: "TextLineReaderV2" output_arg { @@ -64140,6 +64794,88 @@ op { } is_stateful: true } +op { + name: "Unbatch" + input_arg { + name: "batched_tensor" + type_attr: "T" + } + input_arg { + name: "batch_index" + type: DT_INT64 + } + input_arg { + name: "id" + type: DT_INT64 + } + output_arg { + name: "unbatched_tensor" + type_attr: "T" + } + attr { + name: "timeout_micros" + type: "int" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "T" + type: "type" + } +} +op { + name: "UnbatchGrad" + input_arg { + name: "original_input" + type_attr: "T" + } + input_arg { + name: "batch_index" + type: DT_INT64 + } + input_arg { + name: "grad" + type_attr: "T" + } + input_arg { + name: "id" + type: DT_INT64 + } + output_arg { + name: "batched_grad" + type_attr: "T" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "T" + type: "type" + } +} op { name: "UniformCandidateSampler" input_arg { diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index cf949ed64777b545863fcff68257cf678d45ef97..12c27c79840de6981629984732147671b8a1e28e 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -787,7 +787,6 @@ REGISTER_OP("TensorArray") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayV2") .Input("size: int32") .Attr("dtype: type") @@ -802,7 +801,8 @@ REGISTER_OP("TensorArrayV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); c->set_output(0, c->Vector(2)); return Status::OK(); - }); + }) + .Deprecated(26, "Use TensorArrayV3"); REGISTER_OP("TensorArrayGrad") .Input("handle: string") .Input("flow_in: float") @@ -811,7 +811,6 @@ REGISTER_OP("TensorArrayGrad") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayGradV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayGradV2") .Input("handle: string") .Input("flow_in: float") @@ -825,7 +824,8 @@ REGISTER_OP("TensorArrayGradV2") TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); c->set_output(0, c->Vector(2)); return Status::OK(); - }); + }) + .Deprecated(26, "Use TensorArrayGradV3"); REGISTER_OP("TensorArrayWrite") .Input("handle: Ref(string)") .Input("index: int32") @@ -835,7 +835,6 @@ REGISTER_OP("TensorArrayWrite") .Attr("T: type") .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayWriteV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayWriteV2") .Input("handle: string") .Input("index: int32") @@ -853,7 +852,8 @@ REGISTER_OP("TensorArrayWriteV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); return shape_inference::ScalarShape(c); - }); + }) + .Deprecated(26, "Use TensorArrayWriteV3"); REGISTER_OP("TensorArrayRead") .Input("handle: Ref(string)") .Input("index: int32") @@ -862,7 +862,6 @@ REGISTER_OP("TensorArrayRead") .Attr("dtype: type") .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayReadV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayReadV2") .Input("handle: string") .Input("index: int32") @@ -878,7 +877,8 @@ REGISTER_OP("TensorArrayReadV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); return shape_inference::UnknownShape(c); - }); + }) + .Deprecated(26, "Use TensorArrayReadV3"); REGISTER_OP("TensorArrayPack") .Input("handle: Ref(string)") .Input("flow_in: float") @@ -904,7 +904,6 @@ REGISTER_OP("TensorArrayGather") .Attr("element_shape: shape = { unknown_rank: true }") .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayGatherV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayGatherV2") .Input("handle: string") .Input("indices: int32") @@ -920,7 +919,8 @@ REGISTER_OP("TensorArrayGatherV2") TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); return shape_inference::UnknownShape(c); - }); + }) + .Deprecated(26, "Use TensorArrayGatherV3"); REGISTER_OP("TensorArrayScatter") .Input("handle: Ref(string)") .Input("indices: int32") @@ -930,7 +930,6 @@ REGISTER_OP("TensorArrayScatter") .Attr("T: type") .SetShapeFn(shape_inference::UnknownShape) .Deprecated(19, "Use TensorArrayGradV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayScatterV2") .Input("handle: string") .Input("indices: int32") @@ -946,7 +945,8 @@ REGISTER_OP("TensorArrayScatterV2") TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); return shape_inference::ScalarShape(c); - }); + }) + .Deprecated(26, "Use TensorArrayScatterV3"); REGISTER_OP("TensorArrayConcat") .Input("handle: Ref(string)") .Input("flow_in: float") @@ -983,7 +983,6 @@ REGISTER_OP("TensorArraySplit") .Attr("T: type") .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArraySplitV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArraySplitV2") .Input("handle: string") .Input("value: T") @@ -1000,14 +999,14 @@ REGISTER_OP("TensorArraySplitV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); return shape_inference::ScalarShape(c); - }); + }) + .Deprecated(26, "Use TensorArraySplitV3"); REGISTER_OP("TensorArraySize") .Input("handle: Ref(string)") .Input("flow_in: float") .Output("size: int32") .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArraySizeV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArraySizeV2") .Input("handle: string") .Input("flow_in: float") @@ -1018,12 +1017,12 @@ REGISTER_OP("TensorArraySizeV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); return shape_inference::ScalarShape(c); - }); + }) + .Deprecated(26, "Use TensorArraySizeV3"); REGISTER_OP("TensorArrayClose") .Input("handle: Ref(string)") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) .Deprecated(16, "Use TensorArrayCloseV3"); -// TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayCloseV2") .Input("handle: string") .SetShapeFn([](InferenceContext* c) { @@ -1032,7 +1031,8 @@ REGISTER_OP("TensorArrayCloseV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); return Status::OK(); - }); + }) + .Deprecated(26, "Use TensorArrayCloseV3"); // -------------------------------------------------------------------------- diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index b86816bb5412e59fabbf01acb64a1856fc78bbed..2cae814eab1602e72ffcfd100f9813f8f41c6ac9 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -409,53 +409,49 @@ REGISTER_OP("OneShotIterator") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); +namespace { + +Status IteratorGetNextShapeFn(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast(i), output_shape_handle); + } + return Status::OK(); +} + +} // namespace + REGISTER_OP("IteratorGetNext") .Input("iterator: resource") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - return Status::OK(); - }); + .SetShapeFn(IteratorGetNextShapeFn); + +REGISTER_OP("IteratorGetNextSync") + .Input("iterator: resource") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(IteratorGetNextShapeFn); REGISTER_OP("DatasetToSingleElement") .Input("dataset: variant") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - return Status::OK(); - }); + .SetShapeFn(IteratorGetNextShapeFn); REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 31cc662d218802cc6e748b8e262e3c80c317fb22..7484ebb07808a7670d80a4bfdb590e85b94de04f 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -181,7 +181,9 @@ REGISTER_OP("ResizeBilinear") .Input("images: T") .Input("size: int32") .Output("resized_images: float") - .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}") + .Attr( + "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, " + "float, double}") .Attr("align_corners: bool = false") .SetShapeFn(ResizeShapeFn); @@ -212,7 +214,7 @@ REGISTER_OP("ResizeBilinearGrad") .Input("grads: float") .Input("original_image: T") .Output("output: T") - .Attr("T: {float, half, double}") + .Attr("T: {float, bfloat16, half, double}") .Attr("align_corners: bool = false") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(1)); diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 21f0d02ff27924c9361eafcbb545e394d47c7308..7db4d0c4b667ebb11aa95142b88e615248926562 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -272,14 +272,14 @@ REGISTER_OP("WholeFileReaderV2") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); -// TODO(cwhipkey): mark this deprecated in favor of V2. REGISTER_OP("TextLineReader") .Output("reader_handle: Ref(string)") .Attr("skip_header_lines: int = 0") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() - .SetShapeFn(TwoElementOutput); + .SetShapeFn(TwoElementOutput) + .Deprecated(26, "Use TextLineReaderV2"); REGISTER_OP("TextLineReaderV2") .Output("reader_handle: resource") @@ -289,7 +289,6 @@ REGISTER_OP("TextLineReaderV2") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); -// TODO(cwhipkey): mark this deprecated in favor of V2. REGISTER_OP("FixedLengthRecordReader") .Output("reader_handle: Ref(string)") .Attr("header_bytes: int = 0") @@ -299,7 +298,8 @@ REGISTER_OP("FixedLengthRecordReader") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() - .SetShapeFn(TwoElementOutput); + .SetShapeFn(TwoElementOutput) + .Deprecated(26, "Use FixedLengthRecordReaderV2"); REGISTER_OP("FixedLengthRecordReaderV2") .Output("reader_handle: resource") @@ -313,14 +313,14 @@ REGISTER_OP("FixedLengthRecordReaderV2") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); -// TODO(cwhipkey): mark this deprecated in favor of V2. REGISTER_OP("TFRecordReader") .Output("reader_handle: Ref(string)") .Attr("container: string = ''") .Attr("shared_name: string = ''") .Attr("compression_type: string = ''") .SetIsStateful() - .SetShapeFn(TwoElementOutput); + .SetShapeFn(TwoElementOutput) + .Deprecated(26, "Use TFRecordReaderV2"); REGISTER_OP("TFRecordReaderV2") .Output("reader_handle: resource") @@ -337,13 +337,13 @@ REGISTER_OP("LMDBReader") .SetIsStateful() .SetShapeFn(TwoElementOutput); -// TODO(cwhipkey): mark this deprecated in favor of V2. REGISTER_OP("IdentityReader") .Output("reader_handle: Ref(string)") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() - .SetShapeFn(TwoElementOutput); + .SetShapeFn(TwoElementOutput) + .Deprecated(26, "Use IdentityReaderV2"); REGISTER_OP("IdentityReaderV2") .Output("reader_handle: resource") diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index db534857720f5cc611a9b092a628a89b649d3783..fa40f41bb949767f76ee9dae60a3f6312bd80186 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -176,5 +176,81 @@ REGISTER_OP("TensorListFromTensor") return Status::OK(); }); +REGISTER_OP("TensorListElementShape") + .Input("input_handle: variant") + .Output("element_shape: shape_type") + .Attr("shape_type: {int32, int64}") + .SetShapeFn([](shape_inference::InferenceContext* c) { + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data == nullptr) { + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); + } + c->set_output(0, c->Vector(c->Rank((*handle_data)[0].shape))); + return Status::OK(); + }); + +REGISTER_OP("TensorListReserve") + .Input("element_shape: shape_type") + .Input("num_elements: int32") + .Output("handle: variant") + .Attr("element_dtype: type") + .Attr("shape_type: {int32, int64}") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + c->set_output_handle_shapes_and_types( + 0, std::vector{{s, t}}); + return Status::OK(); + }); + +REGISTER_OP("TensorListGetItem") + .Input("input_handle: variant") + .Input("index: int32") + .Output("item: element_dtype") + .Attr("element_dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + auto* handle_data = c->input_handle_shapes_and_types(0); + shape_inference::ShapeHandle element_shape = c->UnknownShape(); + if (handle_data != nullptr) { + const shape_inference::ShapeAndType& list_shape_type = + (*handle_data)[0]; + element_shape = list_shape_type.shape; + if (list_shape_type.dtype != t) { + return errors::InvalidArgument("Expected list with element dtype ", + DataTypeString(t), + " but got list with element dtype ", + DataTypeString(list_shape_type.dtype)); + } + } + c->set_output(0, element_shape); + return Status::OK(); + }); + +REGISTER_OP("TensorListSetItem") + .Input("input_handle: variant") + .Input("index: int32") + .Input("item: element_dtype") + .Output("output_handle: variant") + .Attr("element_dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data == nullptr) { + c->set_output_handle_shapes_and_types(0, {{c->UnknownShape(), t}}); + return Status::OK(); + } + const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; + shape_inference::ShapeHandle s = c->input(2); + TF_RETURN_IF_ERROR(c->Merge(s, list_shape_type.shape, &s)); + c->set_output_handle_shapes_and_types(0, *handle_data); + return Status::OK(); + }); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 82a895a98b7467358c54b620c9af034486fa98f6..b57206c9c4f53fbf73537f466206f5c1b0caefcb 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -2737,6 +2737,76 @@ op { } } } +op { + name: "Batch" + input_arg { + name: "in_tensors" + type_list_attr: "T" + } + output_arg { + name: "batched_tensors" + type_list_attr: "T" + } + output_arg { + name: "batch_index" + type: DT_INT64 + } + output_arg { + name: "id" + type: DT_INT64 + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "grad_timeout_micros" + type: "int" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + minimum: 1 + } +} op { name: "BatchCholesky" input_arg { @@ -8470,6 +8540,10 @@ op { s: "" } } + deprecation { + version: 26 + explanation: "Use FixedLengthRecordReaderV2" + } is_stateful: true } op { @@ -10067,6 +10141,10 @@ op { s: "" } } + deprecation { + version: 26 + explanation: "Use IdentityReaderV2" + } is_stateful: true } op { @@ -10788,6 +10866,30 @@ op { } is_stateful: true } +op { + name: "IteratorGetNextSync" + input_arg { + name: "iterator" + type: DT_RESOURCE + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "IteratorSetStatsAggregator" input_arg { @@ -19629,6 +19731,7 @@ op { type: DT_UINT16 type: DT_INT32 type: DT_INT64 + type: DT_BFLOAT16 type: DT_HALF type: DT_FLOAT type: DT_DOUBLE @@ -19663,6 +19766,7 @@ op { allowed_values { list { type: DT_FLOAT + type: DT_BFLOAT16 type: DT_HALF type: DT_DOUBLE } @@ -28548,6 +28652,10 @@ op { s: "" } } + deprecation { + version: 26 + explanation: "Use TFRecordReaderV2" + } is_stateful: true } op { @@ -28818,6 +28926,10 @@ op { name: "handle" type: DT_STRING } + deprecation { + version: 26 + explanation: "Use TensorArrayCloseV3" + } } op { name: "TensorArrayCloseV3" @@ -28997,6 +29109,10 @@ op { } } } + deprecation { + version: 26 + explanation: "Use TensorArrayGatherV3" + } } op { name: "TensorArrayGatherV3" @@ -29074,6 +29190,10 @@ op { name: "source" type: "string" } + deprecation { + version: 26 + explanation: "Use TensorArrayGradV3" + } is_stateful: true } op { @@ -29183,6 +29303,10 @@ op { name: "dtype" type: "type" } + deprecation { + version: 26 + explanation: "Use TensorArrayReadV3" + } } op { name: "TensorArrayReadV3" @@ -29266,6 +29390,10 @@ op { name: "T" type: "type" } + deprecation { + version: 26 + explanation: "Use TensorArrayScatterV3" + } } op { name: "TensorArrayScatterV3" @@ -29329,6 +29457,10 @@ op { name: "size" type: DT_INT32 } + deprecation { + version: 26 + explanation: "Use TensorArraySizeV3" + } } op { name: "TensorArraySizeV3" @@ -29404,6 +29536,10 @@ op { name: "T" type: "type" } + deprecation { + version: 26 + explanation: "Use TensorArraySplitV3" + } } op { name: "TensorArraySplitV3" @@ -29505,6 +29641,10 @@ op { s: "" } } + deprecation { + version: 26 + explanation: "Use TensorArrayV3" + } is_stateful: true } op { @@ -29622,6 +29762,10 @@ op { name: "T" type: "type" } + deprecation { + version: 26 + explanation: "Use TensorArrayWriteV3" + } } op { name: "TensorArrayWriteV3" @@ -29675,6 +29819,27 @@ op { } is_stateful: true } +op { + name: "TensorListElementShape" + input_arg { + name: "input_handle" + type: DT_VARIANT + } + output_arg { + name: "element_shape" + type_attr: "shape_type" + } + attr { + name: "shape_type" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "TensorListFromTensor" input_arg { @@ -29704,6 +29869,25 @@ op { } } } +op { + name: "TensorListGetItem" + input_arg { + name: "input_handle" + type: DT_VARIANT + } + input_arg { + name: "index" + type: DT_INT32 + } + output_arg { + name: "item" + type_attr: "element_dtype" + } + attr { + name: "element_dtype" + type: "type" + } +} op { name: "TensorListLength" input_arg { @@ -29753,6 +29937,58 @@ op { type: "type" } } +op { + name: "TensorListReserve" + input_arg { + name: "element_shape" + type_attr: "shape_type" + } + input_arg { + name: "num_elements" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "element_dtype" + type: "type" + } + attr { + name: "shape_type" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { + name: "TensorListSetItem" + input_arg { + name: "input_handle" + type: DT_VARIANT + } + input_arg { + name: "index" + type: DT_INT32 + } + input_arg { + name: "item" + type_attr: "element_dtype" + } + output_arg { + name: "output_handle" + type: DT_VARIANT + } + attr { + name: "element_dtype" + type: "type" + } +} op { name: "TensorListStack" input_arg { @@ -29907,6 +30143,10 @@ op { s: "" } } + deprecation { + version: 26 + explanation: "Use TextLineReaderV2" + } is_stateful: true } op { @@ -30289,6 +30529,88 @@ op { } is_stateful: true } +op { + name: "Unbatch" + input_arg { + name: "batched_tensor" + type_attr: "T" + } + input_arg { + name: "batch_index" + type: DT_INT64 + } + input_arg { + name: "id" + type: DT_INT64 + } + output_arg { + name: "unbatched_tensor" + type_attr: "T" + } + attr { + name: "timeout_micros" + type: "int" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "T" + type: "type" + } +} +op { + name: "UnbatchGrad" + input_arg { + name: "original_input" + type_attr: "T" + } + input_arg { + name: "batch_index" + type: DT_INT64 + } + input_arg { + name: "grad" + type_attr: "T" + } + input_arg { + name: "id" + type: DT_INT64 + } + output_arg { + name: "batched_grad" + type_attr: "T" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "T" + type: "type" + } +} op { name: "UniformCandidateSampler" input_arg { diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 6b6be757f6e825ef15e918f0dac9f7bcb0ed22fa..07aecf848326b23b18b58ae60e896150ab7b4ef9 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -102,7 +102,7 @@ cc_library( ":http_request", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_internal", - "@curl//:curl", + "@curl", ], ) @@ -119,7 +119,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:test", - "@curl//:curl", + "@curl", ], ) diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.h b/tensorflow/core/platform/cloud/gcs_dns_cache.h index dd95c18f35053faa500c78da8362fd7691694f84..40f16f10443a6729477310db44b789d71a0ffd48 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache.h +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ -#define THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ +#ifndef TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ +#define TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ #include @@ -74,4 +74,4 @@ class GcsDnsCache { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ +#endif // TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 4b30291076d722973bb12a26a12f60ab2c1d40f7..520720372d9ff12556110967d2c47703ec4b5132 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -117,6 +117,9 @@ constexpr char kReadRequestTimeout[] = "GCS_READ_REQUEST_TIMEOUT_SECS"; // The environment variable to configure the overall request timeout for // upload requests. constexpr char kWriteRequestTimeout[] = "GCS_WRITE_REQUEST_TIMEOUT_SECS"; +// The environment variable to configure an additional header to send with +// all requests to GCS (format HEADERNAME:HEADERCONTENT) +constexpr char kAdditionalRequestHeader[] = "GCS_ADDITIONAL_REQUEST_HEADER"; // TODO: DO NOT use a hardcoded path Status GetTmpFilename(string* filename) { @@ -607,6 +610,11 @@ bool GetEnvVar(const char* varname, bool (*convert)(StringPiece, T*), return convert(env_value, value); } +bool StringPieceIdentity(StringPiece str, StringPiece* value) { + *value = str; + return true; +} + } // namespace GcsFileSystem::GcsFileSystem() @@ -668,6 +676,36 @@ GcsFileSystem::GcsFileSystem() VLOG(1) << "GCS DNS cache is disabled, because " << kResolveCacheSecs << " = 0 (or is not set)"; } + + // Get the additional header + StringPiece add_header_contents; + if (GetEnvVar(kAdditionalRequestHeader, StringPieceIdentity, + &add_header_contents)) { + size_t split = add_header_contents.find(':', 0); + + if (split != StringPiece::npos) { + StringPiece header_name = add_header_contents.substr(0, split); + StringPiece header_value = add_header_contents.substr(split + 1); + + if (!header_name.empty() && !header_value.empty()) { + additional_header_.reset(new std::pair( + header_name.ToString(), header_value.ToString())); + + VLOG(1) << "GCS additional header ENABLED. " + << "Name: " << additional_header_->first << ", " + << "Value: " << additional_header_->second; + } else { + LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: " + << add_header_contents; + } + } else { + LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: " + << add_header_contents; + } + } else { + VLOG(1) << "GCS additional header DISABLED. No environment variable set."; + } + // Apply the overrides for request timeouts uint32 timeout_value; if (GetEnvVar(kRequestConnectionTimeout, strings::safe_strtou32, @@ -696,7 +734,8 @@ GcsFileSystem::GcsFileSystem( uint64 stat_cache_max_age, size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age, size_t matching_paths_cache_max_entries, int64 initial_retry_delay_usec, - TimeoutConfig timeouts) + TimeoutConfig timeouts, + std::pair* additional_header) : auth_provider_(std::move(auth_provider)), http_request_factory_(std::move(http_request_factory)), file_block_cache_( @@ -705,7 +744,8 @@ GcsFileSystem::GcsFileSystem( matching_paths_cache_(new MatchingPathsCache( matching_paths_cache_max_age, matching_paths_cache_max_entries)), timeouts_(timeouts), - initial_retry_delay_usec_(initial_retry_delay_usec) {} + initial_retry_delay_usec_(initial_retry_delay_usec), + additional_header_(additional_header) {} Status GcsFileSystem::NewRandomAccessFile( const string& fname, std::unique_ptr* result) { @@ -1397,6 +1437,11 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr* request) { new_request->AddAuthBearerHeader(auth_token); + if (additional_header_) { + new_request->AddHeader(additional_header_->first, + additional_header_->second); + } + *request = std::move(new_request); return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index adde161a9340da61791e5c781c608caabc75d996..2eae39608e38184450290e86bc12d81494bb8302 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_ #include +#include #include + #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/auth_provider.h" #include "tensorflow/core/platform/cloud/expiring_lru_cache.h" @@ -44,7 +46,8 @@ class GcsFileSystem : public FileSystem { uint64 stat_cache_max_age, size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age, size_t matching_paths_cache_max_entries, - int64 initial_retry_delay_usec, TimeoutConfig timeouts); + int64 initial_retry_delay_usec, TimeoutConfig timeouts, + std::pair* additional_header); Status NewRandomAccessFile( const string& filename, @@ -92,6 +95,12 @@ class GcsFileSystem : public FileSystem { size_t max_bytes() const { return file_block_cache_->max_bytes(); } uint64 max_staleness() const { return file_block_cache_->max_staleness(); } TimeoutConfig timeouts() const { return timeouts_; } + string additional_header_name() const { + return additional_header_ ? additional_header_->first : ""; + } + string additional_header_value() const { + return additional_header_ ? additional_header_->second : ""; + } uint64 stat_cache_max_age() const { return stat_cache_->max_age(); } size_t stat_cache_max_entries() const { return stat_cache_->max_entries(); } @@ -197,6 +206,9 @@ class GcsFileSystem : public FileSystem { /// The initial delay for exponential backoffs when retrying failed calls. const int64 initial_retry_delay_usec_ = 1000000L; + // Additional header material to be transmitted with all GCS requests + std::unique_ptr> additional_header_; + TF_DISALLOW_COPY_AND_ASSIGN(GcsFileSystem); }; diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 772aec527313fc43fef40983d22e313e338bbe02..d452074ce312f98abe6b058ea56d2e0ce4cf047a 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -53,7 +53,8 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr file; TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); @@ -93,7 +94,8 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_differentN) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr file; TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); @@ -137,15 +139,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) { "Range: 18-26\n" "Timeouts: 5 1 20\n", "")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 9 /* block size */, 18 /* max bytes */, - 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); char scratch[100]; StringPiece result; @@ -211,15 +213,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) { "Range: 0-8\n" "Timeouts: 5 1 20\n", "012345678")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 9 /* block size */, 18 /* max bytes */, - 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); char scratch[100]; StringPiece result; @@ -252,15 +254,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) { "Range: 8-15\n" "Timeouts: 5 1 20\n", "89abcdef")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 8 /* block size */, 16 /* max bytes */, - 3600 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 8 /* block size */, 16 /* max bytes */, 3600 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); char scratch[100]; StringPiece result; // There should only be two HTTP requests issued to GCS even though we iterate @@ -294,15 +296,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) { TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) { std::vector requests; - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 0 /* read ahead bytes */, 0 /* max bytes */, - 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); std::unique_ptr file; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -344,7 +346,8 @@ TEST(GcsFileSystemTest, NewWritableFile) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); // Read from the file first, to fill the block cache. std::unique_ptr rfile; @@ -418,7 +421,8 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr file; TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file)); @@ -465,15 +469,15 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 8 /* block size */, 8 /* max bytes */, - 3600 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 8 /* block size */, 8 /* max bytes */, 3600 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); // Pull the file's first block into the cache. This will trigger the first // HTTP request to GCS. std::unique_ptr rfile; @@ -557,7 +561,8 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 2 /* initial retry delay */, kTestTimeoutConfig); + 2 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr file; TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file)); @@ -612,7 +617,8 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr file; TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file)); @@ -641,7 +647,8 @@ TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr file; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -676,15 +683,15 @@ TEST(GcsFileSystemTest, NewAppendableFile) { "Range: 0-31\n" "Timeouts: 5 1 20\n", "01234567")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 32 /* block size */, 32 /* max bytes */, - 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 32 /* block size */, 32 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); // Create an appendable file. This should read the file from GCS, and pull its // contents into the block cache. @@ -717,7 +724,8 @@ TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr file; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -748,7 +756,8 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr region; TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile( @@ -767,7 +776,8 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr region; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -789,7 +799,8 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt")); } @@ -817,7 +828,8 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder")); } @@ -841,7 +853,8 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.FileExists("gs://bucket1")); TF_EXPECT_OK(fs.FileExists("gs://bucket1/")); @@ -869,7 +882,8 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ(errors::Code::NOT_FOUND, fs.FileExists("gs://bucket/path/file1.txt").code()); @@ -894,7 +908,8 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ(errors::Code::INVALID_ARGUMENT, fs.FileExists("gs://bucket2/").code()); EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -924,15 +939,15 @@ TEST(GcsFileSystemTest, FileExists_StatCache) { "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subfolder/\" }]}")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, - 3600 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 3600 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); // The stat cache will ensure that repeated lookups don't trigger additional // HTTP requests. @@ -957,7 +972,8 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -983,7 +999,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -1010,7 +1027,8 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -1036,7 +1054,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children)); @@ -1059,7 +1078,8 @@ TEST(GcsFileSystemTest, GetChildren_Root) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector children; TF_EXPECT_OK(fs.GetChildren("gs://bucket-a-b-c", &children)); @@ -1082,7 +1102,8 @@ TEST(GcsFileSystemTest, GetChildren_Empty) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -1121,7 +1142,8 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children)); @@ -1146,7 +1168,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector result; TF_EXPECT_OK( @@ -1172,7 +1195,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/*/*", &result)); @@ -1199,7 +1223,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file2.txt", &result)); @@ -1223,7 +1248,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result)); @@ -1247,7 +1273,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file3.txt", &result)); @@ -1263,7 +1290,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::vector result; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -1295,7 +1323,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 3600 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); // Repeated calls to fs.GetMatchingPaths on these patterns should not lead to // any additional HTTP requests to GCS. @@ -1336,7 +1365,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 3600 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); // This loop should trigger the first HTTP request to GCS. for (int i = 0; i < 10; i++) { @@ -1377,15 +1407,15 @@ TEST(GcsFileSystemTest, DeleteFile) { "Range: 0-15\n" "Timeouts: 5 1 20\n", "76543210")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 16 /* block size */, 16 /* max bytes */, - 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 16 /* block size */, 16 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + kTestTimeoutConfig, nullptr /* gcs additional header */); // Do an initial read of the file to load its contents into the block cache. char scratch[100]; @@ -1411,7 +1441,8 @@ TEST(GcsFileSystemTest, DeleteFile_NoObjectName) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ(errors::Code::INVALID_ARGUMENT, fs.DeleteFile("gs://bucket/").code()); @@ -1431,7 +1462,8 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/")); } @@ -1458,7 +1490,8 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/")); } @@ -1476,7 +1509,8 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.DeleteDir("gs://bucket")); } @@ -1496,7 +1530,8 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::FAILED_PRECONDITION, fs.DeleteDir("gs://bucket/path/").code()); @@ -1517,7 +1552,8 @@ TEST(GcsFileSystemTest, GetFileSize) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); uint64 size; TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size)); @@ -1533,7 +1569,8 @@ TEST(GcsFileSystemTest, GetFileSize_NoObjectName) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); uint64 size; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -1617,7 +1654,8 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.RenameFile("gs://bucket/path1", "gs://bucket/path2/")); } @@ -1680,15 +1718,15 @@ TEST(GcsFileSystemTest, RenameFile_Object) { "Range: 0-15\n" "Timeouts: 5 1 20\n", "fedcba98")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 16 /* block size */, 64 /* max bytes */, - 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 16 /* block size */, 64 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + kTestTimeoutConfig, nullptr /* gcs additional header */); // Do an initial read of the source and destination files to load their // contents into the block cache. char scratch[100]; @@ -1761,7 +1799,8 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK( fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt")); @@ -1801,7 +1840,8 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ( errors::Code::UNIMPLEMENTED, @@ -1824,7 +1864,8 @@ TEST(GcsFileSystemTest, Stat_Object) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); FileStatistics stat; TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat)); @@ -1856,7 +1897,8 @@ TEST(GcsFileSystemTest, Stat_Folder) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); FileStatistics stat; TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat)); @@ -1887,7 +1929,8 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); FileStatistics stat; EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/path", &stat).code()); @@ -1906,7 +1949,8 @@ TEST(GcsFileSystemTest, Stat_Bucket) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); FileStatistics stat; TF_EXPECT_OK(fs.Stat("gs://bucket/", &stat)); @@ -1928,7 +1972,8 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); FileStatistics stat; EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/", &stat).code()); @@ -1957,15 +2002,15 @@ TEST(GcsFileSystemTest, Stat_Cache) { "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"subfolder/\" }]}")}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, - 3600 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 3600 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + kTestTimeoutConfig, nullptr /* gcs additional header */); // Repeated calls to fs.Stat on these paths should not lead to any additional // HTTP requests to GCS. @@ -1998,15 +2043,15 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) { "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); - GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests)), - 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, - 3600 /* stat cache max age */, - 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 3600 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + kTestTimeoutConfig, nullptr /* gcs additional header */); // There should be a single HTTP request to GCS for fs.Stat in this loop. for (int i = 0; i < 10; i++) { FileStatistics stat; @@ -2048,7 +2093,8 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/file.txt").code()); @@ -2077,7 +2123,8 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::FAILED_PRECONDITION, fs.IsDirectory("gs://bucket/file.txt").code()); @@ -2106,7 +2153,8 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder")); TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder/")); @@ -2131,7 +2179,8 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.IsDirectory("gs://bucket")); TF_EXPECT_OK(fs.IsDirectory("gs://bucket/")); @@ -2150,7 +2199,8 @@ TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/").code()); } @@ -2190,7 +2240,8 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath")); TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath/")); @@ -2215,7 +2266,8 @@ TEST(GcsFileSystemTest, CreateDir_Bucket) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.CreateDir("gs://bucket/")); TF_EXPECT_OK(fs.CreateDir("gs://bucket")); @@ -2285,7 +2337,8 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); int64 undeleted_files, undeleted_dirs; TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files, @@ -2376,7 +2429,8 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); int64 undeleted_files, undeleted_dirs; TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files, @@ -2409,7 +2463,8 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay*/, kTestTimeoutConfig); + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); int64 undeleted_files, undeleted_dirs; EXPECT_EQ(error::Code::NOT_FOUND, @@ -2420,6 +2475,64 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { EXPECT_EQ(1, undeleted_dirs); } +TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { + GcsFileSystem fs1; + EXPECT_EQ("", fs1.additional_header_name()); + EXPECT_EQ("", fs1.additional_header_value()); + + setenv("GCS_ADDITIONAL_REQUEST_HEADER", + "X-Add-Header:My Additional Header Value", 1); + GcsFileSystem fs2; + EXPECT_EQ("X-Add-Header", fs2.additional_header_name()); + EXPECT_EQ("My Additional Header Value", fs2.additional_header_value()); + + setenv("GCS_ADDITIONAL_REQUEST_HEADER", "Someinvalidheadervalue", 1); + GcsFileSystem fs3; + EXPECT_EQ("", fs3.additional_header_name()); + EXPECT_EQ("", fs3.additional_header_value()); + + setenv("GCS_ADDITIONAL_REQUEST_HEADER", ":thisisinvalid", 1); + GcsFileSystem fs4; + EXPECT_EQ("", fs4.additional_header_name()); + EXPECT_EQ("", fs4.additional_header_value()); + + setenv("GCS_ADDITIONAL_REQUEST_HEADER", "soisthis:", 1); + GcsFileSystem fs5; + EXPECT_EQ("", fs5.additional_header_name()); + EXPECT_EQ("", fs5.additional_header_value()); + + setenv("GCS_ADDITIONAL_REQUEST_HEADER", "a:b", 1); + GcsFileSystem fs6; + EXPECT_EQ("a", fs6.additional_header_name()); + EXPECT_EQ("b", fs6.additional_header_value()); + + auto* add_header = new std::pair( + "mynewheader", "newheadercontents"); + + std::vector requests( + {// IsDirectory is checking whether there are children objects. + new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n" + "Auth Token: fake_token\n" + "Header mynewheader: newheadercontents\n" + "Header Hello: world\n", + "{}")}); + GcsFileSystem fs7( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, add_header /* gcs additional header */); + + std::unique_ptr request; + TF_EXPECT_OK(fs7.CreateHttpRequest(&request)); + request->SetUri("https://www.googleapis.com/fake"); + request->AddHeader("Hello", "world"); + TF_EXPECT_OK(request->Send()); +} + TEST(GcsFileSystemTest, OverrideCacheParameters) { // Verify defaults are propagated correctly. GcsFileSystem fs1; @@ -2485,7 +2598,8 @@ TEST(GcsFileSystemTest, CreateHttpRequest) { 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig); + 0 /* initial retry delay */, kTestTimeoutConfig, + nullptr /* gcs additional header */); std::unique_ptr request; TF_EXPECT_OK(fs.CreateHttpRequest(&request)); diff --git a/tensorflow/core/platform/cloud/oauth_client.h b/tensorflow/core/platform/cloud/oauth_client.h index 1614c7b315f67f5976a2d18a6d281afe7459f4f1..519d69acf982c7d004d2c15cd47cf8743669f8fe 100644 --- a/tensorflow/core/platform/cloud/oauth_client.h +++ b/tensorflow/core/platform/cloud/oauth_client.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ #include #include "include/json/json.h" @@ -59,4 +59,4 @@ class OAuthClient { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h index 99ab216e97fc9fcdf02e0776dd252b808a43df7a..546b8d1c4a4842f44f6c490eb05cc3cac29aa023 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.h +++ b/tensorflow/core/platform/cloud/retrying_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ #include #include "tensorflow/core/lib/core/status.h" @@ -47,4 +47,4 @@ class RetryingUtils { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ diff --git a/tensorflow/core/platform/cloud/time_util.h b/tensorflow/core/platform/cloud/time_util.h index b1bb7f111970b51dcd2dcba47a3c20f8388bca42..d6d4bc499fe2430e8f5c97ca23c9db7345de11b4 100644 --- a/tensorflow/core/platform/cloud/time_util.h +++ b/tensorflow/core/platform/cloud/time_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ #include "tensorflow/core/lib/core/status.h" @@ -26,4 +26,4 @@ Status ParseRfc3339Time(const string& time, int64* mtime_nsec); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ diff --git a/tensorflow/core/platform/cuda_libdevice_path.h b/tensorflow/core/platform/cuda_libdevice_path.h index 601d0db6d47c7f09300970a454b618653f8f9596..6ef565ecd3c6460791b49a25fd4277e9393cfdd0 100644 --- a/tensorflow/core/platform/cuda_libdevice_path.h +++ b/tensorflow/core/platform/cuda_libdevice_path.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_ +#define TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_ #include "tensorflow/core/platform/types.h" @@ -29,4 +29,4 @@ string LibdeviceRoot(); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_ +#endif // TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_ diff --git a/tensorflow/core/platform/cupti_wrapper.h b/tensorflow/core/platform/cupti_wrapper.h index c909dcd35bae4eef8cf165aa00349b079245db85..9a17ab60c0d2ebcd4401707a23a76f381aeb5994 100644 --- a/tensorflow/core/platform/cupti_wrapper.h +++ b/tensorflow/core/platform/cupti_wrapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ +#define TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ #include "tensorflow/core/platform/platform.h" @@ -24,4 +24,4 @@ limitations under the License. #include "tensorflow/core/platform/default/gpu/cupti_wrapper.h" #endif -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ +#endif // TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index e9c510c93c67a338df67c0882aef0fcf6ef5e393..2102c5cca383b553c56fb3704596e3d1335c55c2 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -378,6 +378,14 @@ def tf_protos_all(): extra_deps=tf_protos_all_impl(), otherwise=["//tensorflow/core:protos_all_cc"]) +def tf_protos_grappler_impl(): + return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"] + +def tf_protos_grappler(): + return if_static( + extra_deps=tf_protos_grappler_impl(), + otherwise=["//tensorflow/core/grappler/costs:op_performance_data_cc"]) + def tf_env_time_hdrs(): return [ "platform/env_time.h", diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index f2fadb45589a8b44d29db045ca4585b578c5301d..2cd607edbe554cd18d21626e258176e8570282ed 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -122,7 +122,7 @@ cc_library( "//tensorflow/core:protos_cc", "@com_googlesource_code_re2//:re2", "@farmhash_archive//:farmhash", - "@fft2d//:fft2d", + "@fft2d", "@highwayhash//:sip_hash", "@png_archive//:png", ], @@ -140,7 +140,7 @@ cc_library( name = "jpeg", copts = tf_copts(), deps = [ - "@jpeg//:jpeg", + "@jpeg", ], ) diff --git a/tensorflow/core/platform/default/gpu/cupti_wrapper.h b/tensorflow/core/platform/default/gpu/cupti_wrapper.h index 38e01cefad8aac372a1d5e65f984ac62623336de..acd889e47496f8bc1cc9f89c3848848d47c4e91f 100644 --- a/tensorflow/core/platform/default/gpu/cupti_wrapper.h +++ b/tensorflow/core/platform/default/gpu/cupti_wrapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ +#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ #if GOOGLE_CUDA @@ -76,4 +76,4 @@ class CuptiWrapper { #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ +#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ diff --git a/tensorflow/core/platform/demangle.h b/tensorflow/core/platform/demangle.h index c2def217a12dd201245bc8e3e6629f2456198f2e..ce33be2e6899e9770e8cdd7831f16cdb4856d6af 100644 --- a/tensorflow/core/platform/demangle.h +++ b/tensorflow/core/platform/demangle.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ +#define TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ #include "tensorflow/core/platform/types.h" @@ -28,4 +28,4 @@ string Demangle(const char* mangled); } // namespace port } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ +#endif // TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ diff --git a/tensorflow/core/platform/file_statistics.h b/tensorflow/core/platform/file_statistics.h index 7629db6ef9e216d652b819a5bc19af1ab6a38058..9e3489b1adb8c7af1651c1b30539c5083a201979 100644 --- a/tensorflow/core/platform/file_statistics.h +++ b/tensorflow/core/platform/file_statistics.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ +#define TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ #include "tensorflow/core/platform/types.h" @@ -36,4 +36,4 @@ struct FileStatistics { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ +#endif // TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.h b/tensorflow/core/platform/hadoop/hadoop_file_system.h index 447e83158ab6c54a505190bd451fdbdcb678a7f1..5f2b222622cf01033af117f92d49458eeae00e6f 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.h +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_ #include "tensorflow/core/platform/env.h" @@ -70,4 +70,4 @@ class HadoopFileSystem : public FileSystem { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_ +#endif // TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_ diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD index 2cd5f877c9fcc998b6a727e3ae0a92f17a233c9f..3a0ad2e9bd09211aa452f8b39b621343a113785d 100644 --- a/tensorflow/core/platform/s3/BUILD +++ b/tensorflow/core/platform/s3/BUILD @@ -45,8 +45,8 @@ tf_cc_binary( linkshared = 1, deps = [ "//tensorflow/core:framework_headers_lib", - "@aws//:aws", - "@curl//:curl", + "@aws", + "@curl", "@protobuf_archive//:protobuf_headers", ], ) @@ -62,7 +62,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "@aws//:aws", + "@aws", "@boringssl//:crypto", ], alwayslink = 1, @@ -79,7 +79,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "@aws//:aws", + "@aws", ], alwayslink = 1, ) @@ -97,7 +97,7 @@ cc_library( ":s3_crypto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "@aws//:aws", + "@aws", ], alwayslink = 1, ) @@ -117,6 +117,6 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@aws//:aws", + "@aws", ], ) diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index 58ea3156701268b758f52383d123fd3a16d3fd86..2c0babe098f2e7a066338e5cb2a25aedf16db8d9 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/s3/s3_file_system.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/s3/aws_logging.h" #include "tensorflow/core/platform/s3/s3_crypto.h" #include +#include #include #include #include @@ -54,13 +56,37 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() { cfg.endpointOverride = Aws::String(endpoint); } const char* region = getenv("AWS_REGION"); + if (!region) { + // TODO (yongtang): `S3_REGION` should be deprecated after 2.0. + region = getenv("S3_REGION"); + } if (region) { cfg.region = Aws::String(region); } else { - // TODO (yongtang): `S3_REGION` should be deprecated after 2.0. - const char* region = getenv("S3_REGION"); - if (region) { - cfg.region = Aws::String(region); + // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG + // is set with a truthy value. + const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG"); + string load_config = + load_config_env ? str_util::Lowercase(load_config_env) : ""; + if (load_config == "true" || load_config == "1") { + Aws::String config_file; + // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config. + const char* config_file_env = getenv("AWS_CONFIG_FILE"); + if (config_file_env) { + config_file = config_file_env; + } else { + const char* home_env = getenv("HOME"); + if (home_env) { + config_file = home_env; + config_file += "/.aws/config"; + } + } + Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); + loader.Load(); + auto profiles = loader.GetProfiles(); + if (!profiles["default"].GetRegion().empty()) { + cfg.region = profiles["default"].GetRegion(); + } } } const char* use_https = getenv("S3_USE_HTTPS"); @@ -102,6 +128,16 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() { return cfg; }; + +void ShutdownClient(Aws::S3::S3Client *s3_client) { + if (s3_client != nullptr) { + delete s3_client; + Aws::SDKOptions options; + Aws::ShutdownAPI(options); + AWSLogSystem::ShutdownAWSLogging(); + } +} + Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket, string* object) { if (!bucket || !object) { @@ -129,12 +165,12 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket, class S3RandomAccessFile : public RandomAccessFile { public: - S3RandomAccessFile(const string& bucket, const string& object) - : bucket_(bucket), object_(object) {} + S3RandomAccessFile(const string& bucket, const string& object, + std::shared_ptr s3_client) + : bucket_(bucket), object_(object), s3_client_(s3_client) {} Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { - Aws::S3::S3Client s3Client(GetDefaultClientConfig()); Aws::S3::Model::GetObjectRequest getObjectRequest; getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str()); string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1); @@ -142,7 +178,7 @@ class S3RandomAccessFile : public RandomAccessFile { getObjectRequest.SetResponseStreamFactory([]() { return Aws::New(kS3FileSystemAllocationTag); }); - auto getObjectOutcome = s3Client.GetObject(getObjectRequest); + auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest); if (!getObjectOutcome.IsSuccess()) { n = 0; *result = StringPiece(scratch, n); @@ -160,13 +196,16 @@ class S3RandomAccessFile : public RandomAccessFile { private: string bucket_; string object_; + std::shared_ptr s3_client_; }; class S3WritableFile : public WritableFile { public: - S3WritableFile(const string& bucket, const string& object) + S3WritableFile(const string& bucket, const string& object, + std::shared_ptr s3_client) : bucket_(bucket), object_(object), + s3_client_(s3_client), sync_needed_(true), outfile_(Aws::MakeShared( kS3FileSystemAllocationTag, "/tmp/s3_filesystem_XXXXXX", @@ -205,17 +244,13 @@ class S3WritableFile : public WritableFile { if (!sync_needed_) { return Status::OK(); } - Aws::Client::ClientConfiguration clientConfig = GetDefaultClientConfig(); - clientConfig.connectTimeoutMs = 300000; - clientConfig.requestTimeoutMs = 600000; - Aws::S3::S3Client s3Client(clientConfig); Aws::S3::Model::PutObjectRequest putObjectRequest; putObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str()); long offset = outfile_->tellp(); outfile_->seekg(0); putObjectRequest.SetBody(outfile_); putObjectRequest.SetContentLength(offset); - auto putObjectOutcome = s3Client.PutObject(putObjectRequest); + auto putObjectOutcome = this->s3_client_->PutObject(putObjectRequest); outfile_->clear(); outfile_->seekp(offset); if (!putObjectOutcome.IsSuccess()) { @@ -230,6 +265,7 @@ class S3WritableFile : public WritableFile { private: string bucket_; string object_; + std::shared_ptr s3_client_; bool sync_needed_; std::shared_ptr outfile_; }; @@ -248,31 +284,39 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { } // namespace -S3FileSystem::S3FileSystem() { - AWSLogSystem::InitializeAWSLogging(); - - Aws::SDKOptions options; - options.cryptoOptions.sha256Factory_create_fn = []() { - return Aws::MakeShared(S3CryptoAllocationTag); - }; - options.cryptoOptions.sha256HMACFactory_create_fn = []() { - return Aws::MakeShared(S3CryptoAllocationTag); - }; - Aws::InitAPI(options); -} +S3FileSystem::S3FileSystem() : + s3_client_(nullptr, ShutdownClient), client_lock_() {} + +S3FileSystem::~S3FileSystem() {} + +// Initializes s3_client_, if needed, and returns it. +std::shared_ptr S3FileSystem::GetS3Client() { + std::lock_guard lock(this->client_lock_); + + if (this->s3_client_.get() == nullptr) { + AWSLogSystem::InitializeAWSLogging(); -S3FileSystem::~S3FileSystem() { - Aws::SDKOptions options; - Aws::ShutdownAPI(options); + Aws::SDKOptions options; + options.cryptoOptions.sha256Factory_create_fn = []() { + return Aws::MakeShared(S3CryptoAllocationTag); + }; + options.cryptoOptions.sha256HMACFactory_create_fn = []() { + return Aws::MakeShared(S3CryptoAllocationTag); + }; + Aws::InitAPI(options); - AWSLogSystem::ShutdownAWSLogging(); + this->s3_client_ = std::shared_ptr( + new Aws::S3::S3Client(GetDefaultClientConfig())); + } + + return this->s3_client_; } Status S3FileSystem::NewRandomAccessFile( const string& fname, std::unique_ptr* result) { string bucket, object; TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object)); - result->reset(new S3RandomAccessFile(bucket, object)); + result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client())); return Status::OK(); } @@ -280,7 +324,7 @@ Status S3FileSystem::NewWritableFile(const string& fname, std::unique_ptr* result) { string bucket, object; TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object)); - result->reset(new S3WritableFile(bucket, object)); + result->reset(new S3WritableFile(bucket, object, this->GetS3Client())); return Status::OK(); } @@ -295,7 +339,7 @@ Status S3FileSystem::NewAppendableFile(const string& fname, string bucket, object; TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object)); - result->reset(new S3WritableFile(bucket, object)); + result->reset(new S3WritableFile(bucket, object, this->GetS3Client())); while (true) { status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk, @@ -346,7 +390,6 @@ Status S3FileSystem::GetChildren(const string& dir, prefix.push_back('/'); } - Aws::S3::S3Client s3Client(GetDefaultClientConfig()); Aws::S3::Model::ListObjectsRequest listObjectsRequest; listObjectsRequest.WithBucket(bucket.c_str()) .WithPrefix(prefix.c_str()) @@ -357,7 +400,7 @@ Status S3FileSystem::GetChildren(const string& dir, Aws::S3::Model::ListObjectsResult listObjectsResult; do { - auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest); + auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { string error = strings::StrCat( listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -391,11 +434,10 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { string bucket, object; TF_RETURN_IF_ERROR(ParseS3Path(fname, true, &bucket, &object)); - Aws::S3::S3Client s3Client(GetDefaultClientConfig()); if (object.empty()) { Aws::S3::Model::HeadBucketRequest headBucketRequest; headBucketRequest.WithBucket(bucket.c_str()); - auto headBucketOutcome = s3Client.HeadBucket(headBucketRequest); + auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest); if (!headBucketOutcome.IsSuccess()) { string error = strings::StrCat( headBucketOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -413,7 +455,7 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { headObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str()); headObjectRequest.SetResponseStreamFactory( []() { return Aws::New(kS3FileSystemAllocationTag); }); - auto headObjectOutcome = s3Client.HeadObject(headObjectRequest); + auto headObjectOutcome = this->GetS3Client()->HeadObject(headObjectRequest); if (headObjectOutcome.IsSuccess()) { stats->length = headObjectOutcome.GetResult().GetContentLength(); stats->is_directory = 0; @@ -431,7 +473,7 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { .WithMaxKeys(1); listObjectsRequest.SetResponseStreamFactory( []() { return Aws::New(kS3FileSystemAllocationTag); }); - auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest); + auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); if (listObjectsOutcome.IsSuccess()) { if (listObjectsOutcome.GetResult().GetContents().size() > 0) { stats->length = 0; @@ -449,11 +491,11 @@ Status S3FileSystem::DeleteFile(const string& fname) { string bucket, object; TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object)); - Aws::S3::S3Client s3Client(GetDefaultClientConfig()); Aws::S3::Model::DeleteObjectRequest deleteObjectRequest; deleteObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str()); - auto deleteObjectOutcome = s3Client.DeleteObject(deleteObjectRequest); + auto deleteObjectOutcome = + this->GetS3Client()->DeleteObject(deleteObjectRequest); if (!deleteObjectOutcome.IsSuccess()) { string error = strings::StrCat( deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -468,10 +510,9 @@ Status S3FileSystem::CreateDir(const string& dirname) { TF_RETURN_IF_ERROR(ParseS3Path(dirname, true, &bucket, &object)); if (object.empty()) { - Aws::S3::S3Client s3Client(GetDefaultClientConfig()); Aws::S3::Model::HeadBucketRequest headBucketRequest; headBucketRequest.WithBucket(bucket.c_str()); - auto headBucketOutcome = s3Client.HeadBucket(headBucketRequest); + auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest); if (!headBucketOutcome.IsSuccess()) { return errors::NotFound("The bucket ", bucket, " was not found."); } @@ -491,7 +532,6 @@ Status S3FileSystem::DeleteDir(const string& dirname) { string bucket, object; TF_RETURN_IF_ERROR(ParseS3Path(dirname, false, &bucket, &object)); - Aws::S3::S3Client s3Client(GetDefaultClientConfig()); string prefix = object; if (prefix.back() != '/') { prefix.push_back('/'); @@ -502,7 +542,7 @@ Status S3FileSystem::DeleteDir(const string& dirname) { .WithMaxKeys(2); listObjectsRequest.SetResponseStreamFactory( []() { return Aws::New(kS3FileSystemAllocationTag); }); - auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest); + auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); if (listObjectsOutcome.IsSuccess()) { auto contents = listObjectsOutcome.GetResult().GetContents(); if (contents.size() > 1 || @@ -542,8 +582,6 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { } } - Aws::S3::S3Client s3Client(GetDefaultClientConfig()); - Aws::S3::Model::CopyObjectRequest copyObjectRequest; Aws::S3::Model::DeleteObjectRequest deleteObjectRequest; @@ -556,7 +594,7 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { Aws::S3::Model::ListObjectsResult listObjectsResult; do { - auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest); + auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { string error = strings::StrCat( listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -575,7 +613,7 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { copyObjectRequest.SetKey(target_key); copyObjectRequest.SetCopySource(source); - auto copyObjectOutcome = s3Client.CopyObject(copyObjectRequest); + auto copyObjectOutcome = this->GetS3Client()->CopyObject(copyObjectRequest); if (!copyObjectOutcome.IsSuccess()) { string error = strings::StrCat( copyObjectOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -586,7 +624,8 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { deleteObjectRequest.SetBucket(src_bucket.c_str()); deleteObjectRequest.SetKey(src_key.c_str()); - auto deleteObjectOutcome = s3Client.DeleteObject(deleteObjectRequest); + auto deleteObjectOutcome = + this->GetS3Client()->DeleteObject(deleteObjectRequest); if (!deleteObjectOutcome.IsSuccess()) { string error = strings::StrCat( deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ", diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h index 31ba3cecc5d0283f83091bfb687445a6ce87a344..168b8007f3b60c60724682dd7fc4e95f8d15a413 100644 --- a/tensorflow/core/platform/s3/s3_file_system.h +++ b/tensorflow/core/platform/s3/s3_file_system.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_ #define TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_ +#include #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { @@ -53,6 +55,13 @@ class S3FileSystem : public FileSystem { Status GetFileSize(const string& fname, uint64* size) override; Status RenameFile(const string& src, const string& target) override; + private: + // Returns the member S3 client, initializing as-needed. + std::shared_ptr GetS3Client(); + + std::shared_ptr s3_client_; + // Lock held when checking for s3_client_ initialization. + mutex client_lock_; }; } // namespace tensorflow diff --git a/tensorflow/core/platform/s3/s3_file_system_test.cc b/tensorflow/core/platform/s3/s3_file_system_test.cc index 0b42f5fcec0041a01a571b1e38dedaa7ef191c22..d4411d98657811c0bf6858c5ac48c7991e8bed5a 100644 --- a/tensorflow/core/platform/s3/s3_file_system_test.cc +++ b/tensorflow/core/platform/s3/s3_file_system_test.cc @@ -130,6 +130,8 @@ TEST_F(S3FileSystemTest, NewReadOnlyMemoryRegionFromFile) { TEST_F(S3FileSystemTest, FileExists) { const string fname = TmpDir("FileExists"); + // Ensure the file doesn't yet exist. + TF_ASSERT_OK(s3fs.DeleteFile(fname)); EXPECT_EQ(error::Code::NOT_FOUND, s3fs.FileExists(fname).code()); TF_ASSERT_OK(WriteString(fname, "test")); TF_EXPECT_OK(s3fs.FileExists(fname)); diff --git a/tensorflow/core/platform/stacktrace_handler.h b/tensorflow/core/platform/stacktrace_handler.h index d36c82c9ba893b4438c21662156291aa71df77ee..a52970fdaaa6693d537ac42b3d237ce3eb6a7755 100644 --- a/tensorflow/core/platform/stacktrace_handler.h +++ b/tensorflow/core/platform/stacktrace_handler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_ +#define TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_ namespace tensorflow { namespace testing { @@ -25,4 +25,4 @@ void InstallStacktraceHandler(); } // namespace testing } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_ +#endif // TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_ diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index 5fbfc62e74c4ba5e8821eb12eb2cabd1b1c99068..35d99930186381edbb80aa6485856e288f1dd568 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -38,7 +38,7 @@ tf_cc_binary( "//tensorflow/core/profiler/internal:tfprof_stats", "//tensorflow/core/profiler/internal:tfprof_utils", "//tensorflow/core/profiler/internal/advisor:tfprof_advisor", - "@linenoise//:linenoise", + "@linenoise", ], ) diff --git a/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h b/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h index c6544fe0b02df1b317db2ce4ab73130f9f155e56..25766668d88925b0d494e5e80284188cc42fb5cd 100644 --- a/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h +++ b/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // This checker checks the accelerator's utilization. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ #include "tensorflow/core/profiler/internal/advisor/checker.h" @@ -106,4 +106,4 @@ class AcceleratorUtilizationChecker : public Checker { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ diff --git a/tensorflow/core/profiler/internal/advisor/checker.h b/tensorflow/core/profiler/internal/advisor/checker.h index 4b5ebcf9e83742c8aa3cff072f490c6ca0243061..5d7da39e6b27b01a3438c25c26b70e5e3b65c7ff 100644 --- a/tensorflow/core/profiler/internal/advisor/checker.h +++ b/tensorflow/core/profiler/internal/advisor/checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/profiler/internal/tfprof_stats.h" @@ -49,4 +49,4 @@ class Checker { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ diff --git a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h index 145782c7bddc3c98f9bdcab179cc303f25755bd5..f5ac5c9c5a428354f57767e812e8292da21f014d 100644 --- a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h +++ b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // This checker checks the most expensive operations. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_ #include "tensorflow/core/profiler/internal/advisor/checker.h" @@ -137,4 +137,4 @@ class ExpensiveOperationChecker : public Checker { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OP_CHECKER_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OP_CHECKER_H_ diff --git a/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h b/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h index ec52741b19e6769ec9d571666c063524857dd199..6fc16cf903704ec6ce6fd18ebc0ba67962483795 100644 --- a/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h +++ b/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ #include "tensorflow/core/profiler/internal/tfprof_utils.h" #include "tensorflow/core/profiler/tfprof_options.pb.h" @@ -31,4 +31,4 @@ AdviceProto RunInternalCheckers(const AdvisorOptionsProto& options, } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ diff --git a/tensorflow/core/profiler/internal/advisor/operation_checker.h b/tensorflow/core/profiler/internal/advisor/operation_checker.h index f0bd72fa409a87aa512c8a7f50f33d57ec21e3a7..6c1d5cd6704f2aeaa0eeed25a7cf1ecdbb73919c 100644 --- a/tensorflow/core/profiler/internal/advisor/operation_checker.h +++ b/tensorflow/core/profiler/internal/advisor/operation_checker.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // This checker checks common wrong configurations of operations. // -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ #include "tensorflow/core/profiler/internal/advisor/checker.h" @@ -74,4 +74,4 @@ class OperationChecker : public Checker { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h index 42bd6d54381d50a0670ac23a6ae686bcf0b13c81..270662bd4aca9bb0d17957ef43abd4eda2fa8e4d 100644 --- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h +++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ #include "tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h" #include "tensorflow/core/profiler/internal/advisor/checker.h" @@ -78,4 +78,4 @@ class Advisor { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ diff --git a/tensorflow/core/profiler/internal/print_model_analysis.h b/tensorflow/core/profiler/internal/print_model_analysis.h index 90166aa7d5fc16efa1b7d405af4b15491872ad54..29666ab9364253ea5131cf1739a960182e91cee5 100644 --- a/tensorflow/core/profiler/internal/print_model_analysis.h +++ b/tensorflow/core/profiler/internal/print_model_analysis.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ #include @@ -63,4 +63,4 @@ string PrintModelAnalysis(const string* graph, const string* run_meta, } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_code.h b/tensorflow/core/profiler/internal/tfprof_code.h index bcbdc1b48c490b40d7fbf460c7f57a3eefef2a0a..38395f967c102f44fb49c49ced676dd5b6c609de 100644 --- a/tensorflow/core/profiler/internal/tfprof_code.h +++ b/tensorflow/core/profiler/internal/tfprof_code.h @@ -16,8 +16,8 @@ limitations under the License. // Build a tree structure based on the TensorFlow model's python code stacks. // Stats are aggregated from descendants to ancestors. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ #include #include @@ -94,4 +94,4 @@ class TFCode : public TFMultiShow { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_constants.h b/tensorflow/core/profiler/internal/tfprof_constants.h index 6a4eaaa890c51a1c2a730cfbb96d6d45316789c6..d4a47931a2700794b2d2e3cb932bc5d19dc2d90c 100644 --- a/tensorflow/core/profiler/internal/tfprof_constants.h +++ b/tensorflow/core/profiler/internal/tfprof_constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ namespace tensorflow { namespace tfprof { @@ -34,4 +34,4 @@ static const char* const kCkptVarType = "_checkpoint_variables"; } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_graph.h b/tensorflow/core/profiler/internal/tfprof_graph.h index f7eef9c835b1985ccb8436691a35cfd779d94a8d..356a459a65ece4e4395db1da82c99739a6982318 100644 --- a/tensorflow/core/profiler/internal/tfprof_graph.h +++ b/tensorflow/core/profiler/internal/tfprof_graph.h @@ -16,8 +16,8 @@ limitations under the License. // Build a graph structure based on op inputs/outputs. The graph is a directed // acyclic graph pointing *from outputs to inputs*. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ #include #include @@ -86,4 +86,4 @@ class TFGraph : public TFShow { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_node.h b/tensorflow/core/profiler/internal/tfprof_node.h index 255a0987e68400badeb24457e834646c3306f11a..0a97b1cb0f2568656fbc45883a688d0ecc5c95d8 100644 --- a/tensorflow/core/profiler/internal/tfprof_node.h +++ b/tensorflow/core/profiler/internal/tfprof_node.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ #include #include @@ -915,4 +915,4 @@ bool IsCanonicalDevice(const string& device); } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_node_show.h b/tensorflow/core/profiler/internal/tfprof_node_show.h index ca6f9bca5e8fcf1a1e8d39d66b28de7e2fcc3f79..517da67d74c5663ecea4fb914ef0940590400489 100644 --- a/tensorflow/core/profiler/internal/tfprof_node_show.h +++ b/tensorflow/core/profiler/internal/tfprof_node_show.h @@ -21,8 +21,8 @@ limitations under the License. // ScopeNode and GraphNode each maps to one TFGraphNode. // CodeNode and OpNode each maps to one TFMultiGraphNode. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ #include #include @@ -156,4 +156,4 @@ class OpNode : public ShowMultiNode { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_op.h b/tensorflow/core/profiler/internal/tfprof_op.h index fcc5e68f474e643a6e23dc9fc17dce7eca6f04b1..fe1c3b2ae826783c1405b6151b82f153c05d2901 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.h +++ b/tensorflow/core/profiler/internal/tfprof_op.h @@ -15,8 +15,8 @@ limitations under the License. // Build a flat structure of ops. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ #include #include @@ -76,4 +76,4 @@ class TFOp : public TFMultiShow { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_scope.h b/tensorflow/core/profiler/internal/tfprof_scope.h index bb847c08666df232a472aca8c882decb630c736d..235dfde803fa45146484870dfc46ebda367dc29c 100644 --- a/tensorflow/core/profiler/internal/tfprof_scope.h +++ b/tensorflow/core/profiler/internal/tfprof_scope.h @@ -17,8 +17,8 @@ limitations under the License. // For example, 'name1/name2' is a child of 'name1'. // Stats are aggregated from descendants to ancestors. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ #include #include @@ -74,4 +74,4 @@ class TFScope : public TFShow { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_show.h b/tensorflow/core/profiler/internal/tfprof_show.h index 2067ea3b735a07922168f8b557e6cd8be534b408..4d6de060705435c5346f6f49810b7dfc05d4530e 100644 --- a/tensorflow/core/profiler/internal/tfprof_show.h +++ b/tensorflow/core/profiler/internal/tfprof_show.h @@ -15,8 +15,8 @@ limitations under the License. // Parent class and utilities for tfprof_graph and tfprof_scope. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ #include #include @@ -151,4 +151,4 @@ string FormatAcceleratorExecTime(const T* node, const Options& opts) { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_show_multi.h b/tensorflow/core/profiler/internal/tfprof_show_multi.h index ac0ada04490a10330f0596e60f103fbfbe75fe4c..2a2208d8e78efd5bc20d0db23e5fdaabbb3e8d5a 100644 --- a/tensorflow/core/profiler/internal/tfprof_show_multi.h +++ b/tensorflow/core/profiler/internal/tfprof_show_multi.h @@ -15,8 +15,8 @@ limitations under the License. // Parent class and utilities for tfprof_code. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ #include #include @@ -127,4 +127,4 @@ class TFMultiShow { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_stats.h b/tensorflow/core/profiler/internal/tfprof_stats.h index d78abda588b7df5239547a6a3519ce7304c32be1..db148c936c9746c773213a9a49803103814906d3 100644 --- a/tensorflow/core/profiler/internal/tfprof_stats.h +++ b/tensorflow/core/profiler/internal/tfprof_stats.h @@ -20,8 +20,8 @@ limitations under the License. // 3. Accept command and options to selectively aggregate stats for analysis // and print out the results. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ #include #include @@ -83,7 +83,7 @@ class TFStats { const MultiGraphNodeProto& ShowMultiGraphNode(const string& cmd, const Options& opts) const; - // A a (partial) graph to existing graph. + // Add a (partial) graph to existing graph. void AddGraph(std::unique_ptr graph); // Add a step of run time meta data. @@ -118,11 +118,11 @@ class TFStats { MultiGraphNodeProto empty_multi_graph_node_; std::map id_to_string_; - // Graph nodes covered by RunMetdata, that is traced with run time stats. + // Graph nodes covered by RunMetadata, that is traced with run time stats. std::set covered_nodes_; }; } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_tensor.h b/tensorflow/core/profiler/internal/tfprof_tensor.h index 9f72e081c91957f6534334e56bf85a6b1d36a1ba..7a0885772001a1c4b587cff54739264bb5542925 100644 --- a/tensorflow/core/profiler/internal/tfprof_tensor.h +++ b/tensorflow/core/profiler/internal/tfprof_tensor.h @@ -19,8 +19,8 @@ limitations under the License. // is not supported by TensorFlow CheckPointReader library, though it is // supported in current code. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ #include @@ -173,4 +173,4 @@ class TFProfTensor { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.h b/tensorflow/core/profiler/internal/tfprof_timeline.h index b8174cdecbd764ff784049e75d0a62c038c05978..4428ab571f84ff75499f24d78af2547d512a8c1c 100644 --- a/tensorflow/core/profiler/internal/tfprof_timeline.h +++ b/tensorflow/core/profiler/internal/tfprof_timeline.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ #include "include/json/json.h" #include "tensorflow/core/framework/graph.pb.h" @@ -191,4 +191,4 @@ class Timeline { } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_utils.h b/tensorflow/core/profiler/internal/tfprof_utils.h index afca3df7f8cb4d15a4abcecdbf2163fbf4ee8945..d4f80afce0c3145bed18ab677f7537a41dea778c 100644 --- a/tensorflow/core/profiler/internal/tfprof_utils.h +++ b/tensorflow/core/profiler/internal/tfprof_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ #include #include @@ -72,4 +72,4 @@ string QueryDoc(const string& cmd, const Options& opts); } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ diff --git a/tensorflow/core/profiler/tfprof_options.h b/tensorflow/core/profiler/tfprof_options.h index 463f5b3c3a69b3105141faea0c669a83c181bd93..d61deb72ac45517587739722457299acffa18a4c 100644 --- a/tensorflow/core/profiler/tfprof_options.h +++ b/tensorflow/core/profiler/tfprof_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_ #include #include @@ -183,4 +183,4 @@ tensorflow::Status ParseOutput(const string& output_opt, string* output_type, } // namespace tfprof } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_ diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 9b51db1362124bb1db2645711684bd1cbf3e61b5..3e7289bd919015dcb6712ee89ccee3605dc6d907 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -292,7 +292,10 @@ message RecvTensorRequest { // into a RunGraph call on the same WorkerService. int64 step_id = 1; - // A key that identifies the tensor to be received. + // A key identifying the channel to receive tensors from. A RecvTensor request + // retrieves one tensor from the channel, but multiple tensors can be sent and + // received over the same channel with multiple RecvTensor requests. See + // rendezvous.h for details. string rendezvous_key = 2; // If true, use an out-of-band DMA mechanism to transfer the @@ -307,6 +310,16 @@ message RecvTensorRequest { // Optional information needed by the RPC subsystem. google.protobuf.Any transport_options = 6; + + // Unique identifier for this request. Every RecvTensorRequest must have a + // unique request_id, and retried RecvTensorRequests must have the same + // request_id. If request_id is zero, retry detection is disabled. + // + // Retried RecvTensorRequests are problematic because a RecvTensor with no + // corresponding sender will wait forever, and the tensor may have been + // delivered to a previous retry. Workers use request_ids to reject retried + // RecvTensor requests instead of waiting forever. + int64 request_id = 7; } message RecvTensorResponse { diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index adeb080ddef004c55c97a489a21c207362cf2e27..67da7bf4526235ae51eb172f8da9fc267cc12b98 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -94,10 +94,12 @@ limitations under the License. // 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating // whether default-valued attrs have been stripped from the nodes in the // GraphDef. (7dec2017) +// 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops +// deprecated in favor of V2 ops. (2018/01/23) #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 25 +#define TF_GRAPH_DEF_VERSION 26 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index 121c7063c9ebf6d447d0077f612386e316e05624..928ae8a4e9405f30ec994110e9032c6c19dd1b7f 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H -#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H +#ifndef TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H +#define TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H #include #include @@ -134,4 +134,4 @@ class Flags { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H +#endif // TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index d30ab3f4dadb28a0c63632357bc7d631e2bdc81f..53087821d7b4bc0f98e77be9274cbdb4c675c10f 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -52,26 +52,25 @@ struct BeamProbability { float label; }; +template +class BeamRoot; + template struct BeamEntry { - // Default constructor does not create a vector of children. - BeamEntry() : parent(nullptr), label(-1) {} - // Constructor giving parent, label, and number of children does - // create a vector of children. The object pointed to by p - // cannot be copied and should not be moved, otherwise parent will - // become invalid. - BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {} + // BeamRoot::AddEntry() serves as the factory method. + friend BeamEntry* BeamRoot::AddEntry( + BeamEntry* p, int l); inline bool Active() const { return newp.total != kLogZero; } // Return the child at the given index, or construct a new one in-place if // none was found. BeamEntry& GetChild(int ind) { auto entry = children.emplace(ind, nullptr); auto& child_entry = entry.first->second; - // If this is a new child, populate the uniqe_ptr. + // If this is a new child, populate the BeamEntry*. if (entry.second) { - child_entry.reset(new BeamEntry(this, ind)); + child_entry = beam_root->AddEntry(this, ind); } - return *(child_entry.get()); + return *child_entry; } std::vector LabelSeq(bool merge_repeated) const { std::vector labels; @@ -90,15 +89,45 @@ struct BeamEntry { BeamEntry* parent; int label; - gtl::FlatMap>> children; + // All instances of child BeamEntry are owned by *beam_root. + gtl::FlatMap*> children; BeamProbability oldp; BeamProbability newp; CTCBeamState state; private: + // Constructor giving parent, label, and the beam_root. + // The object pointed to by p cannot be copied and should not be moved, + // otherwise parent will become invalid. + // This private constructor is only called through the factory method + // BeamRoot::AddEntry(). + BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) + : parent(p), label(l), beam_root(beam_root) {} + BeamRoot* beam_root; TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry); }; +// This class owns all instances of BeamEntry. This is used to avoid recursive +// destructor call during destruction. +template +class BeamRoot { + public: + BeamRoot(BeamEntry* p, int l) { root_entry_ = AddEntry(p, l); } + BeamRoot(const BeamRoot&) = delete; + BeamRoot& operator=(const BeamRoot&) = delete; + + BeamEntry* AddEntry(BeamEntry* p, int l) { + auto* new_entry = new BeamEntry(p, l, this); + beam_entries_.emplace_back(new_entry); + return new_entry; + } + BeamEntry* RootEntry() const { return root_entry_; } + + private: + BeamEntry* root_entry_ = nullptr; + std::vector>> beam_entries_; +}; + // BeamComparer is the default beam comparer provided in CTCBeamSearch. template class BeamComparer { diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index 372f25a1434036ef6022841665f6f942af046dc1..709c65fc9659e5b76ffa42f6e3a2030e8cdc9676 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -16,11 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ +#include #include +#include #include +#include #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/top_n.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -69,6 +73,7 @@ class CTCBeamSearchDecoder : public CTCDecoder { // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) // but we calculate it recursively for speed purposes. typedef ctc_beam_search::BeamEntry BeamEntry; + typedef ctc_beam_search::BeamRoot BeamRoot; typedef ctc_beam_search::BeamProbability BeamProbability; public: @@ -142,7 +147,7 @@ class CTCBeamSearchDecoder : public CTCDecoder { float label_selection_margin_ = -1; // -1 means unlimited. gtl::TopN leaves_; - std::unique_ptr beam_root_; + std::unique_ptr beam_root_; BaseBeamScorer* beam_scorer_; TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder); @@ -367,15 +372,15 @@ void CTCBeamSearchDecoder::Reset() { // This beam root, and all of its children, will be in memory until // the next reset. - beam_root_.reset(new BeamEntry(nullptr, -1)); - beam_root_->newp.total = 0.0; // ln(1) - beam_root_->newp.blank = 0.0; // ln(1) + beam_root_.reset(new BeamRoot(nullptr, -1)); + beam_root_->RootEntry()->newp.total = 0.0; // ln(1) + beam_root_->RootEntry()->newp.blank = 0.0; // ln(1) // Add the root as the initial leaf. - leaves_.push(beam_root_.get()); + leaves_.push(beam_root_->RootEntry()); // Call initialize state on the root object. - beam_scorer_->InitializeState(&beam_root_->state); + beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); } template diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index 5b28aeb70ad4bd91800dda824f0bdffd5fcbea7c..b8bab69053fa65d4a29eb08ba10154c1b68a184d 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ +#include +#include + #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h index fe59ec77ca9872ada865a27075c733e30a003c21..1b08f0226735d0efe6ab9e8a17453311aa032ab0 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.h +++ b/tensorflow/core/util/example_proto_fast_parsing.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ +#ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ +#define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ #include #include @@ -94,4 +94,4 @@ bool TestFastParse(const string& serialized, Example* example); } // namespace example } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ +#endif // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h index 8b3c6c5a3fa20967377fcf5d9f14a5f1562e73dd..e51170496217d01084ebbc671524ca7829847a41 100644 --- a/tensorflow/core/util/example_proto_helper.h +++ b/tensorflow/core/util/example_proto_helper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ +#ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ +#define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ #include #include @@ -314,4 +314,4 @@ class ParseSingleSequenceExampleAttrs { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ +#endif // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ diff --git a/tensorflow/core/util/matmul_autotune.h b/tensorflow/core/util/matmul_autotune.h index 53666238836b89db3198adce9620fcbd7c59a12c..5846cae2fc73f822633dd0fa1667ee2f55d487bc 100644 --- a/tensorflow/core/util/matmul_autotune.h +++ b/tensorflow/core/util/matmul_autotune.h @@ -15,8 +15,8 @@ limitations under the License. // The utility to check matmul autotune related flags. -#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_ +#ifndef TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_ +#define TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_ namespace tensorflow { @@ -25,4 +25,4 @@ bool MatmulDoFP32ComputationFP16Input(); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_ +#endif // TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_ diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h index abca98f27b534ea3c4fc2bb7832a38ea6f47df0c..25ecccd28550e943e4a7ab9bc1529426ea8454d2 100644 --- a/tensorflow/core/util/strided_slice_op.h +++ b/tensorflow/core/util/strided_slice_op.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ +#ifndef TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ +#define TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -62,4 +62,4 @@ Status ValidateStridedSliceOp( } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ +#endif // TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ diff --git a/tensorflow/docs_src/api_guides/python/contrib.signal.md b/tensorflow/docs_src/api_guides/python/contrib.signal.md index 85ef3ad1341380607f457e9112e39930c569357d..0f7690f80a5bcb4a776df21cf0768f1540f01baf 100644 --- a/tensorflow/docs_src/api_guides/python/contrib.signal.md +++ b/tensorflow/docs_src/api_guides/python/contrib.signal.md @@ -28,14 +28,14 @@ The `axis` parameter to @{tf.contrib.signal.frame} allows you to frame tensors with inner structure (e.g. a spectrogram): ```python -# `magnitude_spectrograms` is a [batch_size, ?, 127] tensor of spectrograms. We +# `magnitude_spectrograms` is a [batch_size, ?, 129] tensor of spectrograms. We # would like to produce overlapping fixed-size spectrogram patches; for example, # for use in a situation where a fixed size input is needed. magnitude_spectrograms = tf.abs(tf.contrib.signal.stft( signals, frame_length=256, frame_step=64, fft_length=256)) -# `spectrogram_patches` is a [batch_size, ?, 64, 127] tensor containing a -# variable number of [64, 127] spectrogram patches per batch item. +# `spectrogram_patches` is a [batch_size, ?, 64, 129] tensor containing a +# variable number of [64, 129] spectrogram patches per batch item. spectrogram_patches = tf.contrib.signal.frame( magnitude_spectrograms, frame_length=64, frame_step=16, axis=1) ``` diff --git a/tensorflow/docs_src/api_guides/python/python_io.md b/tensorflow/docs_src/api_guides/python/python_io.md index a5444408fe8f276028b6cedd5044947051043d31..06282e49d5247ee1ad22eb5bce872ae2c08514e2 100644 --- a/tensorflow/docs_src/api_guides/python/python_io.md +++ b/tensorflow/docs_src/api_guides/python/python_io.md @@ -14,16 +14,16 @@ suitable if fast sharding or other non-sequential access is desired. ## TFRecords Format Details -A TFRecords file contains a sequence of strings with CRC hashes. Each record -has the format +A TFRecords file contains a sequence of strings with CRC32C (32-bit CRC using +the Castagnoli polynomial) hashes. Each record has the format uint64 length uint32 masked_crc32_of_length byte data[length] uint32 masked_crc32_of_data -and the records are concatenated together to produce the file. The CRC32s -are [described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check), -and the mask of a CRC is +and the records are concatenated together to produce the file. CRCs are +[described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check), and +the mask of a CRC is masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md index fa7a94cc0686bb86c8b7033589a4b2da0e02c87c..9f50be5b31cd8b61b81426f50aa9ef9beb3138f2 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/programmers_guide/saved_model.md @@ -736,6 +736,7 @@ The `run` command provides the following two ways to pass inputs to the model: * `--inputs` option enables you to pass numpy ndarray in files. * `--input_exprs` option enables you to pass Python expressions. +* `--input_examples` option enables you to pass `tf.train.Example`. #### `--inputs` @@ -789,19 +790,31 @@ inputs that match the dtype and shape of the model's `SignatureDef`s. For example: ```bsh -`input_key=[[1], [2], [3]]` +`=[[1],[2],[3]]` ``` In addition to Python expressions, you may also pass numpy functions. For example: ```bsh -input_key=np.ones((32, 32, 3)) +`=np.ones((32,32,3))` ``` (Note that the `numpy` module is already available to you as `np`.) +#### `--inputs_examples` + +To pass `tf.train.Example` as inputs, specify the `--input_examples` option. +For each input key, it takes a list of dictionary, where each dictionary is an +instance of `tf.train.Example`. The dictionary keys are the features and the +values are the value lists for each feature. +For example: + +```bsh +`=[{"age":[22,24],"education":["BS","MS"]}]` +``` + #### Save Output By default, the SavedModel CLI writes output to stdout. If a directory is diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 46df5973e89a1a87c79bde95262d04b5be88f54e..12146477972a116903f731a03b9755aafd92acc1 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -92,7 +92,7 @@ android_binary( filegroup( name = "external_assets", srcs = [ - "@inception5h//:model_files", + "@inception_v1//:model_files", "@mobile_ssd//:model_files", "@speech_commands//:model_files", "@stylize//:model_files", diff --git a/tensorflow/examples/android/download-models.gradle b/tensorflow/examples/android/download-models.gradle index 0e2cf65f538f49779b851c3f84259bf839ea90ef..d3b67eab52bfbcf006755bb36396a0d71fb66f77 100644 --- a/tensorflow/examples/android/download-models.gradle +++ b/tensorflow/examples/android/download-models.gradle @@ -9,7 +9,7 @@ */ // hard coded model files // LINT.IfChange -def models = ['inception5h.zip', +def models = ['inception_v1.zip', 'object_detection/ssd_mobilenet_v1_android_export.zip', 'stylize_v1.zip', 'speech_commands_conv_actions.zip'] diff --git a/tensorflow/examples/android/jni/object_tracking/config.h b/tensorflow/examples/android/jni/object_tracking/config.h index 86e9fc71b690f6dfda9658d9f081e990dbb9a612..47de2d2c15b3f7141182efb261a79a40e0da2e93 100644 --- a/tensorflow/examples/android/jni/object_tracking/config.h +++ b/tensorflow/examples/android/jni/object_tracking/config.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ #include @@ -297,4 +297,4 @@ struct TrackerConfig { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/examples/android/jni/object_tracking/flow_cache.h index 8813ab6d71846f5ce2e13a2853594de43d95b0b7..b62e334ecd7de55a31e4904c655c0659b0507639 100644 --- a/tensorflow/examples/android/jni/object_tracking/flow_cache.h +++ b/tensorflow/examples/android/jni/object_tracking/flow_cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ #include "tensorflow/examples/android/jni/object_tracking/geom.h" #include "tensorflow/examples/android/jni/object_tracking/utils.h" @@ -303,4 +303,4 @@ class FlowCache { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.h b/tensorflow/examples/android/jni/object_tracking/frame_pair.h index 8f409fe80612e0115ca03b01ccfd5f7dd8a5f110..6c8ac9be9810327505f0a4f8c80f7099f060a5da 100644 --- a/tensorflow/examples/android/jni/object_tracking/frame_pair.h +++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ #include "tensorflow/examples/android/jni/object_tracking/keypoint.h" @@ -100,4 +100,4 @@ class FramePair { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/geom.h b/tensorflow/examples/android/jni/object_tracking/geom.h index 2819063616566a8f83b0cdb5beee48ebbb55e2f6..c975e40144b47337482dcbd4120d645f44fcaf7d 100644 --- a/tensorflow/examples/android/jni/object_tracking/geom.h +++ b/tensorflow/examples/android/jni/object_tracking/geom.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ #include "tensorflow/examples/android/jni/object_tracking/logging.h" #include "tensorflow/examples/android/jni/object_tracking/utils.h" @@ -316,4 +316,4 @@ inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/gl_utils.h b/tensorflow/examples/android/jni/object_tracking/gl_utils.h index bd5c233f4f31ad3a7d99b762911a9fb0acbcd36a..a29e677d3c534cacf41434e53f6ca286d4c1b17c 100755 --- a/tensorflow/examples/android/jni/object_tracking/gl_utils.h +++ b/tensorflow/examples/android/jni/object_tracking/gl_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ #include #include @@ -52,4 +52,4 @@ inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image-inl.h b/tensorflow/examples/android/jni/object_tracking/image-inl.h index 9c4c389aa716e640f9dc7a9953266f65c3b997bd..61d69908b5508de3f2d2f670ba5f926e9901f751 100644 --- a/tensorflow/examples/android/jni/object_tracking/image-inl.h +++ b/tensorflow/examples/android/jni/object_tracking/image-inl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ #include @@ -641,4 +641,4 @@ inline void Image::FromArray(const T* const pixels, const int stride, } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image.h b/tensorflow/examples/android/jni/object_tracking/image.h index b7a2301f5e1fc0c29ea2b4dd7f539d3438a65871..a436f0e0a13a695e6713eeafaa565495f0353662 100644 --- a/tensorflow/examples/android/jni/object_tracking/image.h +++ b/tensorflow/examples/android/jni/object_tracking/image.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ #include @@ -338,4 +338,4 @@ inline std::ostream& operator<<(std::ostream& stream, const Image& image) { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/examples/android/jni/object_tracking/image_data.h index 445cdb57a310cddd6f3b7e4e01ee105080f3fdd9..c4f91d8cbd801db11ce740c23360a3c021e2b548 100644 --- a/tensorflow/examples/android/jni/object_tracking/image_data.h +++ b/tensorflow/examples/android/jni/object_tracking/image_data.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ #include #include @@ -261,4 +261,4 @@ class ImageData { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image_utils.h b/tensorflow/examples/android/jni/object_tracking/image_utils.h index ac9ffd90f8a167199bbcc777df74c11630a1ef41..b4ad7000b3321e5b921187e0aa3cba69a2bfb2a6 100644 --- a/tensorflow/examples/android/jni/object_tracking/image_utils.h +++ b/tensorflow/examples/android/jni/object_tracking/image_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ #include @@ -295,4 +295,4 @@ inline void NormalizeImage(Image* const image) { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/integral_image.h b/tensorflow/examples/android/jni/object_tracking/integral_image.h index 8e82334abf684dba6de8247d013893baa2cda953..caf9b7d2ab88f17ee7fc614175165133c5513356 100755 --- a/tensorflow/examples/android/jni/object_tracking/integral_image.h +++ b/tensorflow/examples/android/jni/object_tracking/integral_image.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ #include "tensorflow/examples/android/jni/object_tracking/geom.h" #include "tensorflow/examples/android/jni/object_tracking/image-inl.h" @@ -184,4 +184,4 @@ class IntegralImage : public Image { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h index 21fbabb5211ad51ea4c77885c5e8e8135b8aa96e..b81d9e0c1262234cfc6f0c5ba6bdc9a16713283f 100644 --- a/tensorflow/examples/android/jni/object_tracking/jni_utils.h +++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ #include diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint.h b/tensorflow/examples/android/jni/object_tracking/keypoint.h index 719f9aff3f80a2328083aa8fe0bcfff587fb38c6..93405a5b2a83f4bb4ad7d97bef2ff361b3578b94 100644 --- a/tensorflow/examples/android/jni/object_tracking/keypoint.h +++ b/tensorflow/examples/android/jni/object_tracking/keypoint.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ #include "tensorflow/examples/android/jni/object_tracking/geom.h" #include "tensorflow/examples/android/jni/object_tracking/image-inl.h" @@ -45,4 +45,4 @@ inline std::ostream& operator<<(std::ostream& stream, const Keypoint keypoint) { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h index 33d228128d64060123f7aab8b84b23eb87d6fc84..2e85b835a7067b0a1d37908d187680bbc0a91ca6 100644 --- a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h +++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ #include #include @@ -125,4 +125,4 @@ class KeypointDetector { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/logging.h b/tensorflow/examples/android/jni/object_tracking/logging.h index dbc89af2f7ecd52cd1fff449665630fc0107b1af..852a7493993c104e0d0d7837774073dd8355e960 100644 --- a/tensorflow/examples/android/jni/object_tracking/logging.h +++ b/tensorflow/examples/android/jni/object_tracking/logging.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ #include #include @@ -118,4 +118,4 @@ void LogPrintF(const int severity, const char* format, ...); #endif -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/examples/android/jni/object_tracking/object_detector.h index 252556767807a78b0dcbc68c940f5509618cae86..a65c7b0db70bd0fe57826deaab231f545a4fe510 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_detector.h +++ b/tensorflow/examples/android/jni/object_tracking/object_detector.h @@ -20,8 +20,8 @@ limitations under the License. // Defines the ObjectDetector class that is the main interface for detecting // ObjectModelBases in frames. -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ #include #include @@ -227,4 +227,4 @@ class ObjectDetector : public ObjectDetectorBase { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/examples/android/jni/object_tracking/object_model.h index be33aea638bf82df60ba151b64bca26fe261402c..5e81c4908080668849a654450cc10e95ec694889 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_model.h +++ b/tensorflow/examples/android/jni/object_tracking/object_model.h @@ -19,8 +19,8 @@ limitations under the License. // Contains ObjectModelBase declaration. -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ #ifdef __RENDER_OPENGL__ #include @@ -99,4 +99,4 @@ class ObjectModel : public ObjectModelBase { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h index eb281fad3726cf782c1b937c3a213ba7f926bf88..20c7627fc5f0c0718f67eb230d00a8582b637e2c 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_tracker.h +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ #include #include @@ -267,4 +267,4 @@ inline std::ostream& operator<<(std::ostream& stream, } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.h b/tensorflow/examples/android/jni/object_tracking/optical_flow.h index 2206375bebd80e75a9fe2a52609c6d8b3875b65e..f98ae22bd646775871832a40e4c9c0e72916ca4a 100644 --- a/tensorflow/examples/android/jni/object_tracking/optical_flow.h +++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ #include "tensorflow/examples/android/jni/object_tracking/geom.h" #include "tensorflow/examples/android/jni/object_tracking/image-inl.h" @@ -97,4 +97,4 @@ class OpticalFlow { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/examples/android/jni/object_tracking/sprite.h index 05a13fea111941b2f36daa3694fba1a11ecd411a..b54a68458f108bf736a4daf237d34fc10742e1a6 100755 --- a/tensorflow/examples/android/jni/object_tracking/sprite.h +++ b/tensorflow/examples/android/jni/object_tracking/sprite.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ #include #include @@ -199,4 +199,4 @@ class Sprite { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.h b/tensorflow/examples/android/jni/object_tracking/time_log.h index 60911da396c2e7ce0315e1b53a32773bd7b233c3..0073e115963ffc28ed22d5e50809d1e9f70094f4 100644 --- a/tensorflow/examples/android/jni/object_tracking/time_log.h +++ b/tensorflow/examples/android/jni/object_tracking/time_log.h @@ -15,8 +15,8 @@ limitations under the License. // Utility functions for performance profiling. -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ #include @@ -134,4 +134,4 @@ inline static void TimeLog(const char* const str) { inline static void PrintTimeLog() {} #endif -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.h b/tensorflow/examples/android/jni/object_tracking/tracked_object.h index cda14e19d26260703cbc213592c1795865e021a5..d7f1a7019bb2cb93e86d3de9122d597e6d907a7a 100644 --- a/tensorflow/examples/android/jni/object_tracking/tracked_object.h +++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ #ifdef __RENDER_OPENGL__ #include "tensorflow/examples/android/jni/object_tracking/gl_utils.h" @@ -183,4 +183,4 @@ inline std::ostream& operator<<(std::ostream& stream, } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/utils.h b/tensorflow/examples/android/jni/object_tracking/utils.h index 51cdfcdcfb123b8d604d3f33db85628d6c67fb18..2e98734ec4e7e44894cb78e753ac7084d62c87a8 100644 --- a/tensorflow/examples/android/jni/object_tracking/utils.h +++ b/tensorflow/examples/android/jni/object_tracking/utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ #include #include @@ -378,4 +378,4 @@ inline bool Invert2x2(const T* const a, float* const a_inv) { } // namespace tf_tracking -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ diff --git a/tensorflow/examples/speech_commands/accuracy_utils.h b/tensorflow/examples/speech_commands/accuracy_utils.h index 8d918cb64b064e10bd6f3e42e3e56d86d74242c6..eea048365bc9ff53bdd767be436fb657b43793c7 100644 --- a/tensorflow/examples/speech_commands/accuracy_utils.h +++ b/tensorflow/examples/speech_commands/accuracy_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ +#ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ +#define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ #include @@ -57,4 +57,4 @@ void PrintAccuracyStats(const StreamingAccuracyStats& stats); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ +#endif // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ diff --git a/tensorflow/examples/speech_commands/recognize_commands.h b/tensorflow/examples/speech_commands/recognize_commands.h index 7f8041f9ed39c4847b05b2ac748f8f526adbab44..a7cd194bec5612122cdf167aafda9b0786d770d8 100644 --- a/tensorflow/examples/speech_commands/recognize_commands.h +++ b/tensorflow/examples/speech_commands/recognize_commands.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ +#ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ +#define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ #include #include @@ -76,4 +76,4 @@ class RecognizeCommands { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ +#endif // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py index 87cd95165e99f3fa7d8911112865a33570186533..d055d157454d4cb351e8db59eec484f212893fe5 100644 --- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py @@ -21,6 +21,8 @@ from __future__ import print_function import collections import math import os +import sys +import argparse import random from tempfile import gettempdir import zipfile @@ -30,6 +32,24 @@ from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +from tensorflow.contrib.tensorboard.plugins import projector + +# Give a folder path as an argument with '--log_dir' to save +# TensorBoard summaries. Default is a log folder in current directory. +current_path = os.path.dirname(os.path.realpath(sys.argv[0])) + +parser = argparse.ArgumentParser() +parser.add_argument( + '--log_dir', + type=str, + default=os.path.join(current_path, 'log'), + help='The log directory for TensorBoard summaries.') +FLAGS, unparsed = parser.parse_known_args() + +# Create the directory for TensorBoard variables if there is not. +if not os.path.exists(FLAGS.log_dir): + os.makedirs(FLAGS.log_dir) + # Step 1: Download the data. url = 'http://mattmahoney.net/dc/' @@ -61,6 +81,7 @@ def read_data(filename): data = tf.compat.as_str(f.read(f.namelist()[0])).split() return data + vocabulary = read_data(filename) print('Data size', len(vocabulary)) @@ -86,20 +107,22 @@ def build_dataset(words, n_words): reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys())) return data, count, dictionary, reversed_dictionary + # Filling 4 global variables: # data - list of codes (integers from 0 to vocabulary_size-1). # This is the original text but words are replaced by their codes # count - map of words(strings) to count of occurrences # dictionary - map of words(strings) to their codes(integers) # reverse_dictionary - maps codes(integers) to words(strings) -data, count, dictionary, reverse_dictionary = build_dataset(vocabulary, - vocabulary_size) +data, count, dictionary, reverse_dictionary = build_dataset( + vocabulary, vocabulary_size) del vocabulary # Hint to reduce memory. print('Most common words (+UNK)', count[:5]) print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]]) data_index = 0 + # Step 3: Function to generate a training batch for the skip-gram model. def generate_batch(batch_size, num_skips, skip_window): global data_index @@ -129,96 +152,136 @@ def generate_batch(batch_size, num_skips, skip_window): data_index = (data_index + len(data) - span) % len(data) return batch, labels + batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1) for i in range(8): - print(batch[i], reverse_dictionary[batch[i]], - '->', labels[i, 0], reverse_dictionary[labels[i, 0]]) + print(batch[i], reverse_dictionary[batch[i]], '->', labels[i, 0], + reverse_dictionary[labels[i, 0]]) # Step 4: Build and train a skip-gram model. batch_size = 128 embedding_size = 128 # Dimension of the embedding vector. -skip_window = 1 # How many words to consider left and right. -num_skips = 2 # How many times to reuse an input to generate a label. -num_sampled = 64 # Number of negative examples to sample. +skip_window = 1 # How many words to consider left and right. +num_skips = 2 # How many times to reuse an input to generate a label. +num_sampled = 64 # Number of negative examples to sample. # We pick a random validation set to sample nearest neighbors. Here we limit the # validation samples to the words that have a low numeric ID, which by # construction are also the most frequent. These 3 variables are used only for # displaying model accuracy, they don't affect calculation. -valid_size = 16 # Random set of words to evaluate similarity on. +valid_size = 16 # Random set of words to evaluate similarity on. valid_window = 100 # Only pick dev samples in the head of the distribution. valid_examples = np.random.choice(valid_window, valid_size, replace=False) - graph = tf.Graph() with graph.as_default(): # Input data. - train_inputs = tf.placeholder(tf.int32, shape=[batch_size]) - train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1]) - valid_dataset = tf.constant(valid_examples, dtype=tf.int32) + with tf.name_scope('inputs'): + train_inputs = tf.placeholder(tf.int32, shape=[batch_size]) + train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1]) + valid_dataset = tf.constant(valid_examples, dtype=tf.int32) # Ops and variables pinned to the CPU because of missing GPU implementation with tf.device('/cpu:0'): # Look up embeddings for inputs. - embeddings = tf.Variable( - tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0)) - embed = tf.nn.embedding_lookup(embeddings, train_inputs) + with tf.name_scope('embeddings'): + embeddings = tf.Variable( + tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0)) + embed = tf.nn.embedding_lookup(embeddings, train_inputs) # Construct the variables for the NCE loss - nce_weights = tf.Variable( - tf.truncated_normal([vocabulary_size, embedding_size], - stddev=1.0 / math.sqrt(embedding_size))) - nce_biases = tf.Variable(tf.zeros([vocabulary_size])) + with tf.name_scope('weights'): + nce_weights = tf.Variable( + tf.truncated_normal( + [vocabulary_size, embedding_size], + stddev=1.0 / math.sqrt(embedding_size))) + with tf.name_scope('biases'): + nce_biases = tf.Variable(tf.zeros([vocabulary_size])) # Compute the average NCE loss for the batch. # tf.nce_loss automatically draws a new sample of the negative labels each # time we evaluate the loss. # Explanation of the meaning of NCE loss: # http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/ - loss = tf.reduce_mean( - tf.nn.nce_loss(weights=nce_weights, - biases=nce_biases, - labels=train_labels, - inputs=embed, - num_sampled=num_sampled, - num_classes=vocabulary_size)) + with tf.name_scope('loss'): + loss = tf.reduce_mean( + tf.nn.nce_loss( + weights=nce_weights, + biases=nce_biases, + labels=train_labels, + inputs=embed, + num_sampled=num_sampled, + num_classes=vocabulary_size)) + + # Add the loss value as a scalar to summary. + tf.summary.scalar('loss', loss) # Construct the SGD optimizer using a learning rate of 1.0. - optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss) + with tf.name_scope('optimizer'): + optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss) # Compute the cosine similarity between minibatch examples and all embeddings. norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True)) normalized_embeddings = embeddings / norm - valid_embeddings = tf.nn.embedding_lookup( - normalized_embeddings, valid_dataset) + valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, + valid_dataset) similarity = tf.matmul( valid_embeddings, normalized_embeddings, transpose_b=True) + # Merge all summaries. + merged = tf.summary.merge_all() + # Add variable initializer. init = tf.global_variables_initializer() + # Create a saver. + saver = tf.train.Saver() + # Step 5: Begin training. num_steps = 100001 with tf.Session(graph=graph) as session: + # Open a writer to write summaries. + writer = tf.summary.FileWriter(FLAGS.log_dir, session.graph) + # We must initialize all variables before we use them. init.run() print('Initialized') average_loss = 0 for step in xrange(num_steps): - batch_inputs, batch_labels = generate_batch( - batch_size, num_skips, skip_window) + batch_inputs, batch_labels = generate_batch(batch_size, num_skips, + skip_window) feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels} + # Define metadata variable. + run_metadata = tf.RunMetadata() + # We perform one update step by evaluating the optimizer op (including it # in the list of returned values for session.run() - _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict) + # Also, evaluate the merged op to get all summaries from the returned "summary" variable. + # Feed metadata variable to session for visualizing the graph in TensorBoard. + _, summary, loss_val = session.run( + [optimizer, merged, loss], + feed_dict=feed_dict, + run_metadata=run_metadata) average_loss += loss_val + # Add returned summaries to writer in each step. + writer.add_summary(summary, step) + # Add metadata to visualize the graph for the last run. + if step == (num_steps - 1): + writer.add_run_metadata(run_metadata, 'step%d' % step) + + # Add returned summaries to writer in each step. + writer.add_summary(summary, step) + # Add metadata to visualize the graph for the last run. + if step == (num_steps - 1): + writer.add_run_metadata(run_metadata, 'step%d' % step) + if step % 2000 == 0: if step > 0: average_loss /= 2000 @@ -240,6 +303,23 @@ with tf.Session(graph=graph) as session: print(log_str) final_embeddings = normalized_embeddings.eval() + # Write corresponding labels for the embeddings. + with open(FLAGS.log_dir + '/metadata.tsv', 'w') as f: + for i in xrange(vocabulary_size): + f.write(reverse_dictionary[i] + '\n') + + # Save the model for checkpoints. + saver.save(session, os.path.join(FLAGS.log_dir, 'model.ckpt')) + + # Create a configuration for visualizing embeddings with the labels in TensorBoard. + config = projector.ProjectorConfig() + embedding_conf = config.embeddings.add() + embedding_conf.tensor_name = embeddings.name + embedding_conf.metadata_path = os.path.join(FLAGS.log_dir, 'metadata.tsv') + projector.visualize_embeddings(writer, config) + +writer.close() + # Step 6: Visualize the embeddings. @@ -251,21 +331,24 @@ def plot_with_labels(low_dim_embs, labels, filename): for i, label in enumerate(labels): x, y = low_dim_embs[i, :] plt.scatter(x, y) - plt.annotate(label, - xy=(x, y), - xytext=(5, 2), - textcoords='offset points', - ha='right', - va='bottom') + plt.annotate( + label, + xy=(x, y), + xytext=(5, 2), + textcoords='offset points', + ha='right', + va='bottom') plt.savefig(filename) + try: # pylint: disable=g-import-not-at-top from sklearn.manifold import TSNE import matplotlib.pyplot as plt - tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact') + tsne = TSNE( + perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact') plot_only = 500 low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :]) labels = [reverse_dictionary[i] for i in xrange(plot_only)] diff --git a/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h index fa8cb0abe951957e621703b7e2b9a6774200ac33..eada07e06f95f5ad9b97c2e2a992435de3437da9 100644 --- a/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h +++ b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ -#define THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ +#ifndef TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ +#define TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -28,4 +28,4 @@ tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav, tensorflow::int32 stride, float brightness, const tensorflow::string& output_image); -#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ +#endif // TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 7bcc55959cd6822fdcd52ad00d12f29fb17c33ef..5b19c90238ef3bb1361a5e2476e94dd06e76d128 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -116,6 +116,110 @@ func WriteImageSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Ou return scope.AddOperation(opspec) } +// Outputs a `tf.Event` protocol buffer. +// +// When CreateSummaryDbWriter is being used, this op can be useful for +// importing data from event logs. +// +// Arguments: +// writer: A handle to a summary writer. +// event: A string containing a binary-encoded tf.Event proto. +// +// Returns the created operation. +func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ImportEvent", + Input: []tf.Input{ + writer, event, + }, + } + return scope.AddOperation(opspec) +} + +// Outputs a `Summary` protocol buffer with a tensor. +// +// Arguments: +// writer: A handle to a summary writer. +// step: The step to write the summary for. +// tensor: A tensor to serialize. +// tag: The summary's tag. +// summary_metadata: Serialized SummaryMetadata protocol buffer containing +// plugin-related metadata for this summary. +// +// Returns the created operation. +func WriteSummary(scope *Scope, writer tf.Output, step tf.Output, tensor tf.Output, tag tf.Output, summary_metadata tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "WriteSummary", + Input: []tf.Input{ + writer, step, tensor, tag, summary_metadata, + }, + } + return scope.AddOperation(opspec) +} + +// Creates summary database writer accessible by given resource handle. +// +// This can be used to write tensors from the execution graph directly +// to a database. Only SQLite is supported right now. This function +// will create the schema if it doesn't exist. Entries in the Users, +// Experiments, and Runs tables will be created automatically if they +// don't already exist. +// +// Arguments: +// writer: Handle to SummaryWriter resource to overwrite. +// db_uri: For example "file:/tmp/foo.sqlite". +// experiment_name: Can't contain ASCII control characters or <>. Case +// sensitive. If empty, then the Run will not be associated with any +// Experiment. +// run_name: Can't contain ASCII control characters or <>. Case sensitive. +// If empty, then each Tag will not be associated with any Run. +// user_name: Must be valid as both a DNS label and Linux username. If +// empty, then the Experiment will not be associated with any User. +// +// Returns the created operation. +func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CreateSummaryDbWriter", + Input: []tf.Input{ + writer, db_uri, experiment_name, run_name, user_name, + }, + } + return scope.AddOperation(opspec) +} + +// Creates a summary file writer accessible by the given resource handle. +// +// Arguments: +// writer: A handle to the summary writer resource +// logdir: Directory where the event file will be written. +// max_queue: Size of the queue of pending events and summaries. +// flush_millis: How often, in milliseconds, to flush the pending events and +// summaries to disk. +// filename_suffix: Every event file's name is suffixed with this suffix. +// +// Returns the created operation. +func CreateSummaryFileWriter(scope *Scope, writer tf.Output, logdir tf.Output, max_queue tf.Output, flush_millis tf.Output, filename_suffix tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CreateSummaryFileWriter", + Input: []tf.Input{ + writer, logdir, max_queue, flush_millis, filename_suffix, + }, + } + return scope.AddOperation(opspec) +} + // Partitions `data` into `num_partitions` tensors using indices from `partitions`. // // For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` @@ -2357,6 +2461,8 @@ func TensorArrayV2TensorArrayName(value string) TensorArrayV2Attr { } // Deprecated. Use TensorArrayV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayV3 func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV2Attr) (handle tf.Output) { if scope.Err() != nil { return @@ -3117,39 +3223,6 @@ func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output, return op.Output(0) } -// Creates summary database writer accessible by given resource handle. -// -// This can be used to write tensors from the execution graph directly -// to a database. Only SQLite is supported right now. This function -// will create the schema if it doesn't exist. Entries in the Users, -// Experiments, and Runs tables will be created automatically if they -// don't already exist. -// -// Arguments: -// writer: Handle to SummaryWriter resource to overwrite. -// db_uri: For example "file:/tmp/foo.sqlite". -// experiment_name: Can't contain ASCII control characters or <>. Case -// sensitive. If empty, then the Run will not be associated with any -// Experiment. -// run_name: Can't contain ASCII control characters or <>. Case sensitive. -// If empty, then each Tag will not be associated with any Run. -// user_name: Must be valid as both a DNS label and Linux username. If -// empty, then the Experiment will not be associated with any User. -// -// Returns the created operation. -func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CreateSummaryDbWriter", - Input: []tf.Input{ - writer, db_uri, experiment_name, run_name, user_name, - }, - } - return scope.AddOperation(opspec) -} - // Adds Tensor 'bias' to Tensor 'input' for Quantized types. // // Broadcasts the values of bias on dimensions 0..N-2 of 'input'. @@ -5413,6 +5486,72 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f return op.Output(0), op.Output(1), op.Output(2) } +// SummaryWriterAttr is an optional argument to SummaryWriter. +type SummaryWriterAttr func(optionalAttr) + +// SummaryWriterSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func SummaryWriterSharedName(value string) SummaryWriterAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// SummaryWriterContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func SummaryWriterContainer(value string) SummaryWriterAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// Returns a handle to be used to access a summary writer. +// +// The summary writer is an in-graph resource which can be used by ops to write +// summaries to event files. +// +// Returns the summary writer resource. Scalar handle. +func SummaryWriter(scope *Scope, optional ...SummaryWriterAttr) (writer tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SummaryWriter", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes gradients for SparseSegmentMean. +// +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. +// +// Arguments: +// grad: gradient propagated to the SparseSegmentMean op. +// indices: indices passed to the corresponding SparseSegmentMean op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. +func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentMeanGrad", + Input: []tf.Input{ + grad, indices, segment_ids, output_dim0, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Applies softmax to a batched N-D `SparseTensor`. // // The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` @@ -7427,30 +7566,6 @@ func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...Va return op.Output(0) } -// Creates a summary file writer accessible by the given resource handle. -// -// Arguments: -// writer: A handle to the summary writer resource -// logdir: Directory where the event file will be written. -// max_queue: Size of the queue of pending events and summaries. -// flush_millis: How often, in milliseconds, to flush the pending events and -// summaries to disk. -// filename_suffix: Every event file's name is suffixed with this suffix. -// -// Returns the created operation. -func CreateSummaryFileWriter(scope *Scope, writer tf.Output, logdir tf.Output, max_queue tf.Output, flush_millis tf.Output, filename_suffix tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CreateSummaryFileWriter", - Input: []tf.Input{ - writer, logdir, max_queue, flush_millis, filename_suffix, - }, - } - return scope.AddOperation(opspec) -} - // Elementwise computes the bitwise XOR of `x` and `y`. // // The result will have those bits set, that are different in `x` and `y`. The @@ -10353,6 +10468,8 @@ func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, } // Deprecated. Use TensorArraySplitV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return @@ -10908,37 +11025,196 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s return op.Output(0) } -// Component-wise divides a SparseTensor by a dense Tensor. -// -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. +// Flushes the writer's unwritten events. // // Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. +// writer: A handle to the summary writer resource. // -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { +// Returns the created operation. +func FlushSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseDenseCwiseDiv", + Type: "FlushSummaryWriter", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, + writer, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. -type ResourceApplyMomentumAttr func(optionalAttr) +// QuantizeV2Attr is an optional argument to QuantizeV2. +type QuantizeV2Attr func(optionalAttr) -// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. +// QuantizeV2Mode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func QuantizeV2Mode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["mode"] = value + } +} + +// QuantizeV2RoundMode sets the optional round_mode attribute to value. +// If not specified, defaults to "HALF_AWAY_FROM_ZERO" +func QuantizeV2RoundMode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["round_mode"] = value + } +} + +// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. +// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. The +// 'round_mode' attribute controls which rounding tie-breaking algorithm is used +// when rounding float values to their quantized equivalents. +// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) +// if T == qint8, out[i] -= (range(T) + 1) / 2.0 +// ``` +// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// +// *MIN_COMBINED Mode Example* +// +// Assume the input is type float and has a possible range of [0.0, 6.0] and the +// output type is quint8 ([0, 255]). The min_range and max_range values should be +// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each +// value of the input by 255/6 and cast to quint8. +// +// If the output type was qint8 ([-128, 127]), the operation will additionally +// subtract each value by 128 prior to casting, so that the range of values aligns +// with the range of qint8. +// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ``` +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = num_discrete_values / range +// quantized = round(input * range_scale) - round(range_min * range_scale) + +// numeric_limits::min() +// quantized = max(quantized, numeric_limits::min()) +// quantized = min(quantized, numeric_limits::max()) +// ``` +// +// The biggest difference between this and MIN_COMBINED is that the minimum range +// is rounded first, before it's subtracted from the rounded value. With +// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing +// and dequantizing will introduce a larger and larger error. +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. +// If T is signed, this is +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// ```c++ +// s = (max_fixed - min_fixed) / (2 * m) +// ``` +// +// Now we can quantize the elements of our tensor: +// ```c++ +// result = round(input * s) +// ``` +// +// One thing to watch out for is that the operator may choose to adjust the +// requested minimum and maximum values slightly during the quantization process, +// so you should always use the output ports as the range for further calculations. +// For example, if the requested minimum and maximum values are close to equal, +// they will be separated by a small epsilon value to prevent ill-formed quantized +// buffers from being created. Otherwise, you can end up with buffers where all the +// quantized values map to the same float value, which causes problems for +// operations that have to perform further calculations on them. +// +// Arguments: +// +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +// +// +// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output. +func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeV2", + Input: []tf.Input{ + input, min_range, max_range, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Component-wise divides a SparseTensor by a dense Tensor. +// +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. +// +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseDenseCwiseDiv", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, dense, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. +type ResourceApplyMomentumAttr func(optionalAttr) + +// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. // // value: If `True`, updating of the var and accum tensors will be protected // by a lock; otherwise the behavior is undefined, but may exhibit less @@ -11607,6 +11883,8 @@ func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output } // Deprecated. Use TensorArrayReadV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return @@ -14420,293 +14698,81 @@ func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...Ran return op.Output(0) } -// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. -type AvgPool3DGradAttr func(optionalAttr) +// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. +type QuantizedConv2DAttr func(optionalAttr) -// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// QuantizedConv2DOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// QuantizedConv2DDilations sets the optional dilations attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { - m["data_format"] = value + m["dilations"] = value } } -// Computes gradients of average pooling function. +// Computes a 2D convolution given quantized 4D input and filter tensors. +// +// The inputs are quantized tensors where the lowest value represents the real +// number of the associated minimum, and the highest represents the maximum. +// This means that you can only interpret the quantized output in the same way, by +// taking the returned minimum and maximum values into account. // // Arguments: -// orig_input_shape: The original input dimensions. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// +// filter: filter's input_depth dimension must match input's depth dimensions. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// min_filter: The float value that the lowest quantized filter value represents. +// max_filter: The float value that the highest quantized filter value represents. +// strides: The stride of the sliding window for each dimension of the input +// tensor. // padding: The type of padding algorithm to use. // -// Returns The backprop for input. -func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool3DGrad", + Type: "QuantizedConv2D", Input: []tf.Input{ - orig_input_shape, grad, + input, filter, min_input, max_input, min_filter, max_filter, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. -type ParseSingleSequenceExampleAttr func(optionalAttr) +// ResourceGatherAttr is an optional argument to ResourceGather. +type ResourceGatherAttr func(optionalAttr) -// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. -// -// value: A list of Ncontext_sparse types; the data types of data in -// each context Feature given in context_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { return func(m optionalAttr) { - m["context_sparse_types"] = value + m["validate_indices"] = value } } -// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["feature_list_dense_types"] = value - } -} - -// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. -// -// value: A list of Ncontext_dense shapes; the shapes of data in -// each context Feature given in context_dense_keys. -// The number of elements in the Feature corresponding to context_dense_key[j] -// must always equal context_dense_shapes[j].NumEntries(). -// The shape of context_dense_values[j] will match context_dense_shapes[j]. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["context_dense_shapes"] = value - } -} - -// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. -// -// value: A list of Nfeature_list_sparse types; the data types -// of data in each FeatureList given in feature_list_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["feature_list_sparse_types"] = value - } -} - -// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. -// -// value: A list of Nfeature_list_dense shapes; the shapes of -// data in each FeatureList given in feature_list_dense_keys. -// The shape of each Feature in the FeatureList corresponding to -// feature_list_dense_key[j] must always equal -// feature_list_dense_shapes[j].NumEntries(). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["feature_list_dense_shapes"] = value - } -} - -// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. -// -// Arguments: -// serialized: A scalar containing a binary serialized SequenceExample proto. -// feature_list_dense_missing_assumed_empty: A vector listing the -// FeatureList keys which may be missing from the SequenceExample. If the -// associated FeatureList is missing, it is treated as empty. By default, -// any FeatureList not listed in this vector must exist in the SequenceExample. -// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). -// The keys expected in the Examples' features associated with context_sparse -// values. -// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' context features associated with -// dense values. -// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors -// (scalars). The keys expected in the FeatureLists associated with sparse -// values. -// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' feature_lists associated -// with lists of dense values. -// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). -// context_dense_defaults[j] provides default values -// when the SequenceExample's context map lacks context_dense_key[j]. -// If an empty Tensor is provided for context_dense_defaults[j], -// then the Feature context_dense_keys[j] is required. -// The input type is inferred from context_dense_defaults[j], even when it's -// empty. If context_dense_defaults[j] is not empty, its shape must match -// context_dense_shapes[j]. -// debug_name: A scalar containing the name of the serialized proto. -// May contain, for example, table key (descriptive) name for the -// corresponding serialized proto. This is purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty scalar if no name is available. -func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ParseSingleSequenceExample", - Input: []tf.Input{ - serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values -} - -// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. -type QuantizedConv2DAttr func(optionalAttr) - -// QuantizedConv2DOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// QuantizedConv2DDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 2D convolution given quantized 4D input and filter tensors. -// -// The inputs are quantized tensors where the lowest value represents the real -// number of the associated minimum, and the highest represents the maximum. -// This means that you can only interpret the quantized output in the same way, by -// taking the returned minimum and maximum values into account. -// -// Arguments: -// -// filter: filter's input_depth dimension must match input's depth dimensions. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// min_filter: The float value that the lowest quantized filter value represents. -// max_filter: The float value that the highest quantized filter value represents. -// strides: The stride of the sliding window for each dimension of the input -// tensor. -// padding: The type of padding algorithm to use. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedConv2D", - Input: []tf.Input{ - input, filter, min_input, max_input, min_filter, max_filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ResourceGatherAttr is an optional argument to ResourceGather. -type ResourceGatherAttr func(optionalAttr) - -// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Gather slices from the variable pointed to by `resource` according to `indices`. +// Gather slices from the variable pointed to by `resource` according to `indices`. // // `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). // Produces an output tensor with shape `indices.shape + params.shape[1:]` where: @@ -18490,34 +18556,6 @@ func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { return scope.AddOperation(opspec) } -// Adjust the hue of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. -// -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A delta is then applied all the hue values, -// and then remapped back to RGB colorspace. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// delta: A float delta to add to the hue. -// -// Returns The hue-adjusted image or images. -func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustHue", - Input: []tf.Input{ - images, delta, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. type ResourceApplyAdamAttr func(optionalAttr) @@ -18625,76 +18663,10 @@ func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { return op.Output(0) } -// Computes gradients for SparseSegmentMean. -// -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. -// -// Arguments: -// grad: gradient propagated to the SparseSegmentMean op. -// indices: indices passed to the corresponding SparseSegmentMean op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. -func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentMeanGrad", - Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SummaryWriterAttr is an optional argument to SummaryWriter. -type SummaryWriterAttr func(optionalAttr) - -// SummaryWriterSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func SummaryWriterSharedName(value string) SummaryWriterAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// SummaryWriterContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func SummaryWriterContainer(value string) SummaryWriterAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// Returns a handle to be used to access a summary writer. -// -// The summary writer is an in-graph resource which can be used by ops to write -// summaries to event files. -// -// Returns the summary writer resource. Scalar handle. -func SummaryWriter(scope *Scope, optional ...SummaryWriterAttr) (writer tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SummaryWriter", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. -type ResizeBicubicGradAttr func(optionalAttr) - -// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. +// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. +type ResizeBicubicGradAttr func(optionalAttr) + +// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. // // value: If true, rescale grads by (orig_height - 1) / (height - 1), which // exactly aligns the 4 corners of grads and original_image. If false, rescale by @@ -20245,6 +20217,8 @@ func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size } // Deprecated. Use TensorArrayGradV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayGradV3 func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { if scope.Err() != nil { return @@ -21538,6 +21512,8 @@ func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { } // Deprecated. Use TensorArrayGatherV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { if scope.Err() != nil { return @@ -22262,6 +22238,8 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) ( // Deprecated. Use TensorArrayCloseV3 // +// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// // Returns the created operation. func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { @@ -22381,6 +22359,69 @@ func Abs(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// Flushes and closes the summary writer. +// +// Also removes it from the resource manager. To reopen, use another +// CreateSummaryFileWriter op. +// +// Arguments: +// writer: A handle to the summary writer resource. +// +// Returns the created operation. +func CloseSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CloseSummaryWriter", + Input: []tf.Input{ + writer, + }, + } + return scope.AddOperation(opspec) +} + +// StackV2Attr is an optional argument to StackV2. +type StackV2Attr func(optionalAttr) + +// StackV2StackName sets the optional stack_name attribute to value. +// +// value: Overrides the name used for the temporary stack resource. Default +// value is the name of the 'Stack' op (which is guaranteed unique). +// If not specified, defaults to "" +func StackV2StackName(value string) StackV2Attr { + return func(m optionalAttr) { + m["stack_name"] = value + } +} + +// A stack that produces elements in first-in last-out order. +// +// Arguments: +// max_size: The maximum size of the stack if non-negative. If negative, the stack +// size is unlimited. +// elem_type: The type of the elements on the stack. +// +// Returns The handle to the stack. +func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"elem_type": elem_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StackV2", + Input: []tf.Input{ + max_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // OrderedMapStageAttr is an optional argument to OrderedMapStage. type OrderedMapStageAttr func(optionalAttr) @@ -23218,6 +23259,8 @@ func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size } // Deprecated. Use TensorArrayGradV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return @@ -23368,6 +23411,8 @@ func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output t } // Deprecated. Use TensorArrayScatterV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return @@ -23572,6 +23617,8 @@ func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, o } // Deprecated. Use TensorArraySizeV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { if scope.Err() != nil { return @@ -25440,245 +25487,204 @@ func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// DecodeWavAttr is an optional argument to DecodeWav. -type DecodeWavAttr func(optionalAttr) - -// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. -// -// value: Number of sample channels wanted. -// If not specified, defaults to -1 -func DecodeWavDesiredChannels(value int64) DecodeWavAttr { - return func(m optionalAttr) { - m["desired_channels"] = value - } -} - -// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. -// -// value: Length of audio requested. -// If not specified, defaults to -1 -func DecodeWavDesiredSamples(value int64) DecodeWavAttr { - return func(m optionalAttr) { - m["desired_samples"] = value - } -} - -// Decode a 16-bit PCM WAV file to a float tensor. -// -// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. -// -// When desired_channels is set, if the input contains fewer channels than this -// then the last channel will be duplicated to give the requested number, else if -// the input has more channels than requested then the additional channels will be -// ignored. +// Adjust the hue of one or more images. // -// If desired_samples is set, then the audio will be cropped or padded with zeroes -// to the requested length. +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. // -// The first output contains a Tensor with the content of the audio samples. The -// lowest dimension will be the number of channels, and the second will be the -// number of samples. For example, a ten-sample-long stereo WAV file should give an -// output shape of [10, 2]. +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A delta is then applied all the hue values, +// and then remapped back to RGB colorspace. // // Arguments: -// contents: The WAV-encoded audio, usually from a file. +// images: Images to adjust. At least 3-D. +// delta: A float delta to add to the hue. // -// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. -func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { +// Returns The hue-adjusted image or images. +func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DecodeWav", + Type: "AdjustHue", Input: []tf.Input{ - contents, + images, delta, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// UniqueAttr is an optional argument to Unique. -type UniqueAttr func(optionalAttr) +// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. +type AvgPool3DGradAttr func(optionalAttr) -// UniqueOutIdx sets the optional out_idx attribute to value. -// If not specified, defaults to DT_INT32 -func UniqueOutIdx(value tf.DataType) UniqueAttr { +// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { return func(m optionalAttr) { - m["out_idx"] = value + m["data_format"] = value } } -// Finds unique elements in a 1-D tensor. -// -// This operation returns a tensor `y` containing all of the unique elements of `x` -// sorted in the same order that they occur in `x`. This operation also returns a -// tensor `idx` the same size as `x` that contains the index of each value of `x` -// in the unique output `y`. In other words: -// -// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` -// -// For example: -// -// ``` -// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] -// y, idx = unique(x) -// y ==> [1, 2, 4, 7, 8] -// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] -// ``` +// Computes gradients of average pooling function. // // Arguments: -// x: 1-D. +// orig_input_shape: The original input dimensions. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns 1-D.1-D. -func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) { +// Returns The backprop for input. +func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Unique", + Type: "AvgPool3DGrad", Input: []tf.Input{ - x, + orig_input_shape, grad, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Concatenates a list of `N` tensors along the first dimension. -// -// The input tensors are all required to have size 1 in the first dimension. +// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. +type ParseSingleSequenceExampleAttr func(optionalAttr) + +// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. // -// For example: -// -// ``` -// # 'x' is [[1, 4]] -// # 'y' is [[2, 5]] -// # 'z' is [[3, 6]] -// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -// ``` -// -// The difference between concat and parallel_concat is that concat requires all -// of the inputs be computed before the operation will begin but doesn't require -// that the input shapes be known during graph construction. Parallel concat -// will copy pieces of the input into the output as they become available, in -// some situations this can provide a performance benefit. -// -// Arguments: -// values: Tensors to be concatenated. All must have size 1 in the first dimension -// and same shape. -// shape: the final shape of the result; should be equal to the shapes of any input -// but with the number of input values in the first dimension. +// value: A list of Ncontext_sparse types; the data types of data in +// each context Feature given in context_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> // -// Returns The concatenated tensor. -func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shape": shape} - opspec := tf.OpSpec{ - Type: "ParallelConcat", - Input: []tf.Input{ - tf.OutputList(values), - }, - Attrs: attrs, +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_sparse_types"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Concatenates tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. +// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. +// If not specified, defaults to <> // -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Concat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), - }, +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_types"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Compute the lower regularized incomplete Gamma function `Q(a, x)`. -// -// The lower regularized incomplete Gamma function is defined as: -// -// -// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) +// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. // -// where +// value: A list of Ncontext_dense shapes; the shapes of data in +// each context Feature given in context_dense_keys. +// The number of elements in the Feature corresponding to context_dense_key[j] +// must always equal context_dense_shapes[j].NumEntries(). +// The shape of context_dense_values[j] will match context_dense_shapes[j]. +// If not specified, defaults to <> // -// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_dense_shapes"] = value + } +} + +// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. // -// is the lower incomplete Gamma function. +// value: A list of Nfeature_list_sparse types; the data types +// of data in each FeatureList given in feature_list_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> // -// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete -// Gamma function. -func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Igamma", - Input: []tf.Input{ - a, x, - }, +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_sparse_types"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes offsets of concat inputs within its output. -// -// For example: +// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. // -// ``` -// # 'x' is [2, 2, 7] -// # 'y' is [2, 3, 7] -// # 'z' is [2, 5, 7] -// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] -// ``` +// value: A list of Nfeature_list_dense shapes; the shapes of +// data in each FeatureList given in feature_list_dense_keys. +// The shape of each Feature in the FeatureList corresponding to +// feature_list_dense_key[j] must always equal +// feature_list_dense_shapes[j].NumEntries(). +// If not specified, defaults to <> // -// This is typically used by gradient computations for a concat operation. +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_shapes"] = value + } +} + +// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. // // Arguments: -// concat_dim: The dimension along which to concatenate. -// shape: The `N` int32 vectors representing shape of tensors being concatenated. -// -// Returns The `N` int32 vectors representing the starting offset -// of input tensors within the concatenated output. -func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { +// serialized: A scalar containing a binary serialized SequenceExample proto. +// feature_list_dense_missing_assumed_empty: A vector listing the +// FeatureList keys which may be missing from the SequenceExample. If the +// associated FeatureList is missing, it is treated as empty. By default, +// any FeatureList not listed in this vector must exist in the SequenceExample. +// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). +// The keys expected in the Examples' features associated with context_sparse +// values. +// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' context features associated with +// dense values. +// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors +// (scalars). The keys expected in the FeatureLists associated with sparse +// values. +// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' feature_lists associated +// with lists of dense values. +// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). +// context_dense_defaults[j] provides default values +// when the SequenceExample's context map lacks context_dense_key[j]. +// If an empty Tensor is provided for context_dense_defaults[j], +// then the Feature context_dense_keys[j] is required. +// The input type is inferred from context_dense_defaults[j], even when it's +// empty. If context_dense_defaults[j] is not empty, its shape must match +// context_dense_shapes[j]. +// debug_name: A scalar containing the name of the serialized proto. +// May contain, for example, table key (descriptive) name for the +// corresponding serialized proto. This is purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty scalar if no name is available. +func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ConcatOffset", + Type: "ParseSingleSequenceExample", Input: []tf.Input{ - concat_dim, tf.OutputList(shape), + serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, }, + Attrs: attrs, } op := scope.AddOperation(opspec) if scope.Err() != nil { @@ -25686,228 +25692,509 @@ func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset } var idx int var err error - if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { - scope.UpdateErr("ConcatOffset", err) + if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) return } - return offset -} - -// Splits a tensor into `num_split` tensors along one dimension. -// -// Arguments: -// axis: 0-D. The dimension along which to split. Must be in the range -// `[-rank(value), rank(value))`. -// value: The tensor to split. -// num_split: The number of ways to split. Must evenly divide -// `value.shape[split_dim]`. -// -// Returns They are identically shaped tensors, whose shape matches that of `value` -// except along `axis`, where their sizes are -// `values.shape[split_dim] / num_split`. -func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) { - if scope.Err() != nil { + if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) return } - attrs := map[string]interface{}{"num_split": num_split} - opspec := tf.OpSpec{ - Type: "Split", - Input: []tf.Input{ - axis, value, - }, - Attrs: attrs, + if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return } - op := scope.AddOperation(opspec) - if scope.Err() != nil { + if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) return } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("Split", err) + if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) return } - return output + if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values } -// Splits a tensor into `num_split` tensors along one dimension. +// DecodeWavAttr is an optional argument to DecodeWav. +type DecodeWavAttr func(optionalAttr) + +// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. // -// Arguments: -// value: The tensor to split. -// size_splits: list containing the sizes of each output tensor along the split -// dimension. Must sum to the dimension of value along split_dim. -// Can contain one -1 indicating that dimension is to be inferred. -// axis: 0-D. The dimension along which to split. Must be in the range -// `[-rank(value), rank(value))`. +// value: Number of sample channels wanted. +// If not specified, defaults to -1 +func DecodeWavDesiredChannels(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_channels"] = value + } +} + +// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. +// +// value: Length of audio requested. +// If not specified, defaults to -1 +func DecodeWavDesiredSamples(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_samples"] = value + } +} + +// Decode a 16-bit PCM WAV file to a float tensor. // +// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. // -// Returns Tensors whose shape matches that of `value` -// except along `axis`, where their sizes are -// `size_splits[i]`. -func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { +// When desired_channels is set, if the input contains fewer channels than this +// then the last channel will be duplicated to give the requested number, else if +// the input has more channels than requested then the additional channels will be +// ignored. +// +// If desired_samples is set, then the audio will be cropped or padded with zeroes +// to the requested length. +// +// The first output contains a Tensor with the content of the audio samples. The +// lowest dimension will be the number of channels, and the second will be the +// number of samples. For example, a ten-sample-long stereo WAV file should give an +// output shape of [10, 2]. +// +// Arguments: +// contents: The WAV-encoded audio, usually from a file. +// +// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. +func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_split": num_split} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SplitV", + Type: "DecodeWav", Input: []tf.Input{ - value, size_splits, axis, + contents, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("SplitV", err) - return + return op.Output(0), op.Output(1) +} + +// UniqueAttr is an optional argument to Unique. +type UniqueAttr func(optionalAttr) + +// UniqueOutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func UniqueOutIdx(value tf.DataType) UniqueAttr { + return func(m optionalAttr) { + m["out_idx"] = value } - return output } -// Gives a guarantee to the TF runtime that the input tensor is a constant. +// Finds unique elements in a 1-D tensor. // -// The runtime is then free to make optimizations based on this. +// This operation returns a tensor `y` containing all of the unique elements of `x` +// sorted in the same order that they occur in `x`. This operation also returns a +// tensor `idx` the same size as `x` that contains the index of each value of `x` +// in the unique output `y`. In other words: // -// Only accepts value typed tensors as inputs and rejects resource variable handles -// as input. +// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` // -// Returns the input tensor without modification. -func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) { +// For example: +// +// ``` +// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +// y, idx = unique(x) +// y ==> [1, 2, 4, 7, 8] +// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +// ``` +// +// Arguments: +// x: 1-D. +// +// Returns 1-D.1-D. +func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "GuaranteeConst", + Type: "Unique", Input: []tf.Input{ - input, + x, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Returns a tensor of zeros with the same shape and type as x. +// Concatenates a list of `N` tensors along the first dimension. +// +// The input tensors are all required to have size 1 in the first dimension. +// +// For example: +// +// ``` +// # 'x' is [[1, 4]] +// # 'y' is [[2, 5]] +// # 'z' is [[3, 6]] +// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +// ``` +// +// The difference between concat and parallel_concat is that concat requires all +// of the inputs be computed before the operation will begin but doesn't require +// that the input shapes be known during graph construction. Parallel concat +// will copy pieces of the input into the output as they become available, in +// some situations this can provide a performance benefit. // // Arguments: -// x: a tensor of type T. +// values: Tensors to be concatenated. All must have size 1 in the first dimension +// and same shape. +// shape: the final shape of the result; should be equal to the shapes of any input +// but with the number of input values in the first dimension. // -// Returns a tensor of the same shape and type as x but filled with zeros. -func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { +// Returns The concatenated tensor. +func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "ZerosLike", + Type: "ParallelConcat", Input: []tf.Input{ - x, + tf.OutputList(values), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Flips all bits elementwise. +// Concatenates tensors along one dimension. // -// The result will have exactly those bits set, that are not set in `x`. The -// computation is performed on the underlying representation of x. -func Invert(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Invert", + Type: "Concat", Input: []tf.Input{ - x, + concat_dim, tf.OutputList(values), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// DequantizeAttr is an optional argument to Dequantize. -type DequantizeAttr func(optionalAttr) - -// DequantizeMode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func DequantizeMode(value string) DequantizeAttr { - return func(m optionalAttr) { - m["mode"] = value - } -} - -// Dequantize the 'input' tensor into a float Tensor. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. -// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// if T == qint8, in[i] += (range(T) + 1)/ 2.0 -// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) -// ``` -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* +// Compute the lower regularized incomplete Gamma function `Q(a, x)`. // -// If the input comes from a QuantizedRelu6, the output type is -// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is -// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. -// Dequantize on quint8 will take each value, cast to float, and multiply -// by 6 / 255. -// Note that if quantizedtype is qint8, the operation will additionally add -// each value by 128 prior to casting. +// The lower regularized incomplete Gamma function is defined as: // -// If the mode is 'MIN_FIRST', then this approach is used: // -// ```c++ -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = range / num_discrete_values -// const double offset_input = static_cast(input) - lowest_quantized; -// result = range_min + ((input - numeric_limits::min()) * range_scale) -// ``` +// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) // -// *SCALED mode Example* +// where // -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. +// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) // -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. +// is the lower incomplete Gamma function. // -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` +// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +// Gamma function. +func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Igamma", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes offsets of concat inputs within its output. // -// Our input tensor range is then `[-m, m]`. +// For example: // -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. -// If T is signed, this is // ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// # 'x' is [2, 2, 7] +// # 'y' is [2, 3, 7] +// # 'z' is [2, 5, 7] +// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] // ``` // -// Otherwise, if T is unsigned, the fixed-point range is -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` +// This is typically used by gradient computations for a concat operation. // -// From this we compute our scaling factor, s: -// ```c++ +// Arguments: +// concat_dim: The dimension along which to concatenate. +// shape: The `N` int32 vectors representing shape of tensors being concatenated. +// +// Returns The `N` int32 vectors representing the starting offset +// of input tensors within the concatenated output. +func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatOffset", + Input: []tf.Input{ + concat_dim, tf.OutputList(shape), + }, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { + scope.UpdateErr("ConcatOffset", err) + return + } + return offset +} + +// Splits a tensor into `num_split` tensors along one dimension. +// +// Arguments: +// axis: 0-D. The dimension along which to split. Must be in the range +// `[-rank(value), rank(value))`. +// value: The tensor to split. +// num_split: The number of ways to split. Must evenly divide +// `value.shape[split_dim]`. +// +// Returns They are identically shaped tensors, whose shape matches that of `value` +// except along `axis`, where their sizes are +// `values.shape[split_dim] / num_split`. +func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_split": num_split} + opspec := tf.OpSpec{ + Type: "Split", + Input: []tf.Input{ + axis, value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("Split", err) + return + } + return output +} + +// Splits a tensor into `num_split` tensors along one dimension. +// +// Arguments: +// value: The tensor to split. +// size_splits: list containing the sizes of each output tensor along the split +// dimension. Must sum to the dimension of value along split_dim. +// Can contain one -1 indicating that dimension is to be inferred. +// axis: 0-D. The dimension along which to split. Must be in the range +// `[-rank(value), rank(value))`. +// +// +// Returns Tensors whose shape matches that of `value` +// except along `axis`, where their sizes are +// `size_splits[i]`. +func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_split": num_split} + opspec := tf.OpSpec{ + Type: "SplitV", + Input: []tf.Input{ + value, size_splits, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("SplitV", err) + return + } + return output +} + +// Gives a guarantee to the TF runtime that the input tensor is a constant. +// +// The runtime is then free to make optimizations based on this. +// +// Only accepts value typed tensors as inputs and rejects resource variable handles +// as input. +// +// Returns the input tensor without modification. +func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "GuaranteeConst", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns a tensor of zeros with the same shape and type as x. +// +// Arguments: +// x: a tensor of type T. +// +// Returns a tensor of the same shape and type as x but filled with zeros. +func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ZerosLike", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Flips all bits elementwise. +// +// The result will have exactly those bits set, that are not set in `x`. The +// computation is performed on the underlying representation of x. +func Invert(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Invert", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DequantizeAttr is an optional argument to Dequantize. +type DequantizeAttr func(optionalAttr) + +// DequantizeMode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func DequantizeMode(value string) DequantizeAttr { + return func(m optionalAttr) { + m["mode"] = value + } +} + +// Dequantize the 'input' tensor into a float Tensor. +// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. +// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// if T == qint8, in[i] += (range(T) + 1)/ 2.0 +// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) +// ``` +// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// +// *MIN_COMBINED Mode Example* +// +// If the input comes from a QuantizedRelu6, the output type is +// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is +// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. +// Dequantize on quint8 will take each value, cast to float, and multiply +// by 6 / 255. +// Note that if quantizedtype is qint8, the operation will additionally add +// each value by 128 prior to casting. +// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ```c++ +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = range / num_discrete_values +// const double offset_input = static_cast(input) - lowest_quantized; +// result = range_min + ((input - numeric_limits::min()) * range_scale) +// ``` +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. +// If T is signed, this is +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// ```c++ // s = (2 * m) / (max_fixed - min_fixed) // ``` // @@ -27805,535 +28092,266 @@ func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { // QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. // If not specified, defaults to 0 func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["input_min"] = value - } -} - -// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["input_max"] = value - } -} - -// Use QuantizeAndDequantizeV2 instead. -// -// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 -func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizeAndDequantize", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the diagonal part of the tensor. -// -// This operation returns a tensor with the `diagonal` part -// of the `input`. The `diagonal` part is computed as follows: -// -// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a -// tensor of rank `k` with dimensions `[D1,..., Dk]` where: -// -// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. -// -// For example: -// -// ``` -// # 'input' is [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] -// -// tf.diag_part(input) ==> [1, 2, 3, 4] -// ``` -// -// Arguments: -// input: Rank k tensor where k is even and not zero. -// -// Returns The extracted diagonal. -func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DiagPart", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. -type QuantizedInstanceNormAttr func(optionalAttr) - -// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. -// -// value: If True, `given_y_min` and `given_y_min` -// and `given_y_max` are used as the output range. Otherwise, -// the implementation computes the output range. -// If not specified, defaults to false -func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["output_range_given"] = value - } -} - -// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. -// -// value: Output in `y_min` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_min"] = value - } -} - -// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. -// -// value: Output in `y_max` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_max"] = value - } -} - -// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. -// -// value: A small float number to avoid dividing by 0. -// If not specified, defaults to 1e-05 -func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["variance_epsilon"] = value - } -} - -// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. -// -// value: Minimum value of `y_max - y_min` -// If not specified, defaults to 0.001 -func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["min_separation"] = value - } -} - -// Quantized Instance normalization. -// -// Arguments: -// x: A 4D input Tensor. -// x_min: The value represented by the lowest quantized input. -// x_max: The value represented by the highest quantized input. -// -// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. -func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedInstanceNorm", - Input: []tf.Input{ - x, x_min, x_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. -type FakeQuantWithMinMaxVarsAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` -// -// and `max` to 'outputs' tensor of same shape as `inputs`. -// -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. -// -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVars", - Input: []tf.Input{ - inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FakeQuantWithMinMaxVarsPerChannelGradientAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannelGradient. -type FakeQuantWithMinMaxVarsPerChannelGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsPerChannelGradientNumBits sets the optional num_bits attribute to value. -// -// value: The bitwidth of the quantization; between 2 and 8, inclusive. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value. -// -// value: Whether to quantize into 2^num_bits - 1 distinct values. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation, -// shape one of: `[d]`, `[b, d]`, `[b, h, w, d]`. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape -// same as `gradients`. -// min, max: Quantization interval, floats of shape `[d]`. -// -// -// -// Returns Backpropagated gradients w.r.t. inputs, shape same as -// `inputs`: -// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter, shape `[d]`: -// `sum_per_d(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter, shape `[d]`: -// `sum_per_d(gradients * (inputs > max))`. -func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsPerChannelGradient", - Input: []tf.Input{ - gradients, inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QuantizeV2Attr is an optional argument to QuantizeV2. -type QuantizeV2Attr func(optionalAttr) - -// QuantizeV2Mode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func QuantizeV2Mode(value string) QuantizeV2Attr { - return func(m optionalAttr) { - m["mode"] = value - } -} - -// QuantizeV2RoundMode sets the optional round_mode attribute to value. -// If not specified, defaults to "HALF_AWAY_FROM_ZERO" -func QuantizeV2RoundMode(value string) QuantizeV2Attr { - return func(m optionalAttr) { - m["round_mode"] = value - } -} - -// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. The -// 'round_mode' attribute controls which rounding tie-breaking algorithm is used -// when rounding float values to their quantized equivalents. -// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) -// if T == qint8, out[i] -= (range(T) + 1) / 2.0 -// ``` -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* -// -// Assume the input is type float and has a possible range of [0.0, 6.0] and the -// output type is quint8 ([0, 255]). The min_range and max_range values should be -// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each -// value of the input by 255/6 and cast to quint8. -// -// If the output type was qint8 ([-128, 127]), the operation will additionally -// subtract each value by 128 prior to casting, so that the range of values aligns -// with the range of qint8. -// -// If the mode is 'MIN_FIRST', then this approach is used: -// -// ``` -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = num_discrete_values / range -// quantized = round(input * range_scale) - round(range_min * range_scale) + -// numeric_limits::min() -// quantized = max(quantized, numeric_limits::min()) -// quantized = min(quantized, numeric_limits::max()) -// ``` -// -// The biggest difference between this and MIN_COMBINED is that the minimum range -// is rounded first, before it's subtracted from the rounded value. With -// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing -// and dequantizing will introduce a larger and larger error. -// -// *SCALED mode Example* -// -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. -// -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. -// -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` -// -// Our input tensor range is then `[-m, m]`. -// -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. -// If T is signed, this is -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` -// -// Otherwise, if T is unsigned, the fixed-point range is -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` -// -// From this we compute our scaling factor, s: -// ```c++ -// s = (max_fixed - min_fixed) / (2 * m) -// ``` -// -// Now we can quantize the elements of our tensor: -// ```c++ -// result = round(input * s) -// ``` -// -// One thing to watch out for is that the operator may choose to adjust the -// requested minimum and maximum values slightly during the quantization process, -// so you should always use the output ports as the range for further calculations. -// For example, if the requested minimum and maximum values are close to equal, -// they will be separated by a small epsilon value to prevent ill-formed quantized -// buffers from being created. Otherwise, you can end up with buffers where all the -// quantized values map to the same float value, which causes problems for -// operations that have to perform further calculations on them. -// -// Arguments: -// -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. -// + return func(m optionalAttr) { + m["input_min"] = value + } +} + +// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. +// If not specified, defaults to 0 +func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["input_max"] = value + } +} + +// Use QuantizeAndDequantizeV2 instead. // -// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output. -func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 +func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeV2", + Type: "QuantizeAndDequantize", Input: []tf.Input{ - input, min_range, max_range, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Flushes the writer's unwritten events. +// Returns the diagonal part of the tensor. +// +// This operation returns a tensor with the `diagonal` part +// of the `input`. The `diagonal` part is computed as follows: +// +// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a +// tensor of rank `k` with dimensions `[D1,..., Dk]` where: +// +// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. +// +// For example: +// +// ``` +// # 'input' is [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] +// +// tf.diag_part(input) ==> [1, 2, 3, 4] +// ``` // // Arguments: -// writer: A handle to the summary writer resource. +// input: Rank k tensor where k is even and not zero. // -// Returns the created operation. -func FlushSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) { +// Returns The extracted diagonal. +func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FlushSummaryWriter", + Type: "DiagPart", Input: []tf.Input{ - writer, + input, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// StackV2Attr is an optional argument to StackV2. -type StackV2Attr func(optionalAttr) +// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. +type QuantizedInstanceNormAttr func(optionalAttr) -// StackV2StackName sets the optional stack_name attribute to value. +// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. // -// value: Overrides the name used for the temporary stack resource. Default -// value is the name of the 'Stack' op (which is guaranteed unique). -// If not specified, defaults to "" -func StackV2StackName(value string) StackV2Attr { +// value: If True, `given_y_min` and `given_y_min` +// and `given_y_max` are used as the output range. Otherwise, +// the implementation computes the output range. +// If not specified, defaults to false +func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { return func(m optionalAttr) { - m["stack_name"] = value + m["output_range_given"] = value } } -// A stack that produces elements in first-in last-out order. +// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. +// +// value: Output in `y_min` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_min"] = value + } +} + +// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. +// +// value: Output in `y_max` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_max"] = value + } +} + +// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. +// +// value: A small float number to avoid dividing by 0. +// If not specified, defaults to 1e-05 +func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["variance_epsilon"] = value + } +} + +// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. +// +// value: Minimum value of `y_max - y_min` +// If not specified, defaults to 0.001 +func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["min_separation"] = value + } +} + +// Quantized Instance normalization. // // Arguments: -// max_size: The maximum size of the stack if non-negative. If negative, the stack -// size is unlimited. -// elem_type: The type of the elements on the stack. +// x: A 4D input Tensor. +// x_min: The value represented by the lowest quantized input. +// x_max: The value represented by the highest quantized input. // -// Returns The handle to the stack. -func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { +// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. +func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StackV2", + Type: "QuantizedInstanceNorm", Input: []tf.Input{ - max_size, + x, x_min, x_max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Flushes and closes the summary writer. +// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. +type FakeQuantWithMinMaxVarsAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // -// Also removes it from the resource manager. To reopen, use another -// CreateSummaryFileWriter op. +// and `max` to 'outputs' tensor of same shape as `inputs`. // -// Arguments: -// writer: A handle to the summary writer resource. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // -// Returns the created operation. -func CloseSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) { +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "CloseSummaryWriter", + Type: "FakeQuantWithMinMaxVars", Input: []tf.Input{ - writer, + inputs, min, max, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Outputs a `Summary` protocol buffer with a tensor. -// -// Arguments: -// writer: A handle to a summary writer. -// step: The step to write the summary for. -// tensor: A tensor to serialize. -// tag: The summary's tag. -// summary_metadata: Serialized SummaryMetadata protocol buffer containing -// plugin-related metadata for this summary. +// FakeQuantWithMinMaxVarsPerChannelGradientAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannelGradient. +type FakeQuantWithMinMaxVarsPerChannelGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsPerChannelGradientNumBits sets the optional num_bits attribute to value. // -// Returns the created operation. -func WriteSummary(scope *Scope, writer tf.Output, step tf.Output, tensor tf.Output, tag tf.Output, summary_metadata tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "WriteSummary", - Input: []tf.Input{ - writer, step, tensor, tag, summary_metadata, - }, +// value: The bitwidth of the quantization; between 2 and 8, inclusive. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value } - return scope.AddOperation(opspec) } -// Outputs a `tf.Event` protocol buffer. +// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value. // -// When CreateSummaryDbWriter is being used, this op can be useful for -// importing data from event logs. +// value: Whether to quantize into 2^num_bits - 1 distinct values. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. // // Arguments: -// writer: A handle to a summary writer. -// event: A string containing a binary-encoded tf.Event proto. +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation, +// shape one of: `[d]`, `[b, d]`, `[b, h, w, d]`. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape +// same as `gradients`. +// min, max: Quantization interval, floats of shape `[d]`. // -// Returns the created operation. -func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) { +// +// +// Returns Backpropagated gradients w.r.t. inputs, shape same as +// `inputs`: +// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter, shape `[d]`: +// `sum_per_d(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter, shape `[d]`: +// `sum_per_d(gradients * (inputs > max))`. +func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ImportEvent", + Type: "FakeQuantWithMinMaxVarsPerChannelGradient", Input: []tf.Input{ - writer, event, + gradients, inputs, min, max, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3493ed76f3d00d5af2f065d30de279ac2109aab1..01b3e92d2d9edc12afc6c98da44a4442796592e9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -32,6 +32,7 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py") load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps") load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_grappler") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps") load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_verbs_deps") @@ -209,9 +210,8 @@ cc_library( "//tensorflow/core/grappler/costs:analytical_cost_estimator", "//tensorflow/core/grappler/costs:cost_estimator", "//tensorflow/core/grappler/costs:measuring_cost_estimator", - "//tensorflow/core/grappler/costs:op_performance_data_cc", "//tensorflow/core/grappler/costs:utils", - ], + ] + tf_protos_grappler(), ) cc_library( @@ -1387,6 +1387,13 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "batch_ops_gen", + visibility = [ + "//tensorflow:__subpackages__", + ], +) + tf_gen_op_wrapper_private_py( name = "math_ops_gen", visibility = [ @@ -1951,6 +1958,7 @@ py_library( srcs = ["ops/list_ops.py"], srcs_version = "PY2AND3", deps = [ + ":array_ops", ":list_ops_gen", ], ) diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index 53c8be1d1dc8b2f23b4faef7d64350edffede34a..bd80b9dbf561de16168b05facf0086dadcda6444 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -50,8 +51,9 @@ class BatchDatasetTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count).batch(batch_size).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).batch(batch_size).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -67,7 +69,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, + self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -82,12 +84,12 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(8): - self.assertAllEqual(component[(i*8 + j) % 7]**2, + self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) result = sess.run(get_next) for component, result_component in zip(components, result): for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2, + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -188,33 +190,34 @@ class BatchDatasetTest(test.TestCase): sess.run(get_next) def testBatchShapeError(self): + def generator(): yield [1.0, 2.0, 3.0] yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32, - output_shapes=[None]) - .batch(3) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_generator( + generator, dtypes.float32, output_shapes=[None]).batch(3) + .make_initializable_iterator()) next_element = iterator.get_next() with self.test_session() as sess: sess.run(iterator.initializer) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r"Cannot batch tensors with different shapes in component 0. " - r"First element had shape \[3\] and element 2 had shape \[4\]."): + r'Cannot batch tensors with different shapes in component 0. ' + r'First element had shape \[3\] and element 2 had shape \[4\].'): sess.run(next_element) def testPaddedBatchDataset(self): seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens) - .map(lambda x: array_ops.fill([x], x)).padded_batch( - 4, - padded_shapes=padded_shape).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens) + .map(lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=padded_shape).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -222,35 +225,40 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result) self.assertEqual((4, padded_len), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test with random sequence lengths, and constant padding. - sess.run(init_op, feed_dict={padded_shape: [25], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [25], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) self.assertEqual((4, 25), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test correct handling of empty tensors. - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: [0, 0, 0, 0]}) + sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]}) result = sess.run(get_next) self.assertAllEqual([[], [], [], []], result) with self.assertRaises(errors.OutOfRangeError): @@ -258,8 +266,7 @@ class BatchDatasetTest(test.TestCase): # Test error handling with constant sequence lengths, and # too-short padding. - sess.run(init_op, feed_dict={padded_shape: [5], - seq_lens: [6, 5, 5, 5]}) + sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]}) with self.assertRaises(errors.DataLossError): result = sess.run(get_next) @@ -270,11 +277,13 @@ class BatchDatasetTest(test.TestCase): def fill_tuple(x): filled = array_ops.fill([x], x) return (filled, string_ops.as_string(filled)) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) - .padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")).make_initializable_iterator()) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) + .padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, '')).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -282,25 +291,46 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result[0]) self.assertEqual((4, padded_len), result[0].shape) self.assertEqual((4, padded_len), result[1].shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[0][j, seq_len:], [-1] * (padded_len - seq_len)) self.assertAllEqual(result[1][j, :seq_len], [compat.as_bytes(str(seq_len))] * seq_len) self.assertAllEqual(result[1][j, seq_len:], - [b""] * (padded_len - seq_len)) + [b''] * (padded_len - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testPaddedBatchDatasetUnicode(self): + # See GitHub issue 16149 + def generator(): + data = [[u'Простой', u'тест', u'юникода'], + [u'никогда', u'не', u'бывает', u'простым']] + + for seq in data: + yield seq, [0, 1, 2, 3] + + dataset = dataset_ops.Dataset.from_generator( + generator, (dtypes.string, dtypes.int32), + (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None]))) + padded_dataset = dataset.padded_batch( + 2, padded_shapes=([None], [None]), padding_values=('', 0)) + with self.test_session() as sess: + next_element = padded_dataset.make_one_shot_iterator().get_next() + sess.run(next_element) + def testPaddedBatchDatasetShapeSpecifications(self): int_placeholder = array_ops.placeholder(dtypes.int32) float_placeholder = array_ops.placeholder(dtypes.float32) @@ -324,15 +354,16 @@ class BatchDatasetTest(test.TestCase): constant_op.constant([-1, -1], dtype=dtypes.int64), constant_op.constant([37], dtype=dtypes.int64))) - for dataset in [dynamic_padding_from_tensor_shapes, - dynamic_padding_from_lists, - dynamic_padding_from_lists_with_minus_one, - dynamic_padding_from_tensors]: + for dataset in [ + dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists, + dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors + ]: self.assertEqual([None, None], dataset.output_shapes[0].as_list()) self.assertEqual([None, None, None], dataset.output_shapes[1].as_list()) self.assertEqual([None, 37], dataset.output_shapes[2].as_list()) def testPaddedBatchSparseError(self): + def _map_fn(i): return sparse_tensor.SparseTensorValue( indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i @@ -341,5 +372,5 @@ class BatchDatasetTest(test.TestCase): _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 0594c6d6a7325ae0952f012e0d543e5c80edb529..c1ba67e4744c6282f0fd3d9a388aabc1ed51267b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -201,7 +201,7 @@ class Dataset(object): tensors: A nested structure of tensors. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return TensorDataset(tensors) @@ -214,7 +214,7 @@ class Dataset(object): 0th dimension. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return TensorSliceDataset(tensors) @@ -227,7 +227,7 @@ class Dataset(object): sparse_tensor: A `tf.SparseTensor`. Returns: - A `Dataset` of rank-(N-1) sparse tensors. + Dataset: A `Dataset` of rank-(N-1) sparse tensors. """ return SparseTensorSliceDataset(sparse_tensor) @@ -313,7 +313,7 @@ class Dataset(object): `generator`. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ if not callable(generator): raise TypeError("`generator` must be callable.") @@ -456,7 +456,7 @@ class Dataset(object): len(args) == 3 -> start = args[0], stop = args[1, stop = args[2] Returns: - A `RangeDataset`. + Dataset: A `RangeDataset`. Raises: ValueError: if len(args) == 0. @@ -500,7 +500,7 @@ class Dataset(object): datasets: A nested structure of datasets. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return ZipDataset(datasets) @@ -526,7 +526,7 @@ class Dataset(object): dataset: `Dataset` to be concatenated. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return ConcatenateDataset(self, dataset) @@ -538,7 +538,7 @@ class Dataset(object): maximum number elements that will be buffered when prefetching. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return PrefetchDataset(self, buffer_size) @@ -561,7 +561,7 @@ class Dataset(object): the filename pattern that will be matched. Returns: - A `Dataset` of strings corresponding to file names. + Dataset: A `Dataset` of strings corresponding to file names. """ return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) @@ -578,7 +578,7 @@ class Dataset(object): indefinitely. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return RepeatDataset(self, count) @@ -602,7 +602,7 @@ class Dataset(object): iterated over. (Defaults to `True`.) Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration) @@ -615,7 +615,7 @@ class Dataset(object): If a filename is not provided, the dataset will be cached in memory. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return CacheDataset(self, filename) @@ -629,7 +629,7 @@ class Dataset(object): dataset, the new dataset will contain all elements of this dataset. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return TakeDataset(self, count) @@ -644,7 +644,7 @@ class Dataset(object): is -1, skips the entire dataset. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return SkipDataset(self, count) @@ -691,7 +691,7 @@ class Dataset(object): index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. Returns: - A `Dataset`. + Dataset: A `Dataset`. Raises: ValueError: if `num_shards` or `index` are illegal values. Note: error @@ -735,7 +735,7 @@ class Dataset(object): consecutive elements of this dataset to combine in a single batch. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return BatchDataset(self, batch_size) @@ -764,7 +764,7 @@ class Dataset(object): the empty string for string types. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values) @@ -780,7 +780,7 @@ class Dataset(object): specified, elements will be processed sequentially. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ if num_parallel_calls is None: return MapDataset(self, map_func) @@ -796,7 +796,7 @@ class Dataset(object): `Dataset`. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return FlatMapDataset(self, map_func) @@ -865,7 +865,7 @@ class Dataset(object): input element before cycling to another input element. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return InterleaveDataset(self, map_func, cycle_length, block_length) @@ -878,7 +878,7 @@ class Dataset(object): scalar `tf.bool` tensor. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return FilterDataset(self, predicate) @@ -902,7 +902,7 @@ class Dataset(object): returns a `Dataset`. Returns: - The `Dataset` returned by applying `transformation_func` to this dataset. + Dataset: The `Dataset` returned by applying `transformation_func` to this dataset. """ dataset = transformation_func(self) if not isinstance(dataset, Dataset): diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index 6fd89e018aa3b2a21dad4b56a4aa1a5b01a1d69d..b6c7280a415b367751c4900a302e5af61f260cb0 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -39,7 +39,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): def setUp(self): self.sess = session.Session() - with self.sess: + with self.sess.as_default(): self.u = variables.Variable(2.0, name="u") self.v = variables.Variable(3.0, name="v") self.w = math_ops.multiply(self.u.value(), self.v.value(), name="w") diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index f470e181200f19d672cced3ea21d05aa2eee0bea..9e3382d4f301529cd2b476bc76efe7dfd2be9298 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -1,8 +1,7 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load( "//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark", @@ -423,6 +422,22 @@ cuda_py_test( ], ) +py_test( + name = "pywrap_tfe_test", + srcs = ["pywrap_tfe_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":backprop", + ":context", + ":test", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 9849f0f322eff2d909e7396158539a9663b95f29..75526ba9c139e78dfe9e3de271f1316924539371 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -42,11 +42,25 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops - CPU = "/device:CPU:0" GPU = "/device:GPU:0" +def record_gradient_callback(inputs, attrs, results): + return backprop._record_gradient("MatMul", inputs, attrs, results, None) + + +def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False): + ctx = context.context() + assert not ctx.in_graph_mode( + ), "The prototype doesn't contain C code for graph construction" + ctx_handle = ctx._handle # pylint: disable=protected-access + + return pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx_handle, None, "MatMul", record_gradient_callback, a, b, + "transpose_a", transpose_a, "transpose_b", transpose_b)[0] + + class MicroBenchmarks(test.Benchmark): def __init__(self): @@ -222,6 +236,14 @@ class MicroBenchmarks(test.Benchmark): gen_math_ops._mat_mul(m, m, transpose_b=transpose_b) self._run(func, num_iters) + def _benchmark_tfe_py_fastpath_execute_matmul(self, m, transpose_b, + num_iters): + + def func(): + c_tfe_py_fastpath_execute(m, m, transpose_b=transpose_b) + + self._run(func, num_iters) + def _benchmark_tfe_py_execute_matmul(self, m, transpose_b, num_iters): inputs = [m, m] # pylint: disable=protected-access @@ -257,6 +279,12 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_gen_math_ops_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) + def benchmark_tfe_py_fastpath_execute_matmul_2_by_2_CPU(self): + with context.device(CPU): + m = self._m_2_by_2.cpu() + self._benchmark_tfe_py_fastpath_execute_matmul( + m, transpose_b=False, num_iters=self._num_iters_2_by_2) + def benchmark_tfe_py_execute_matmul_2_by_2_CPU(self): with context.device(CPU): m = self._m_2_by_2.cpu() @@ -320,6 +348,12 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_gen_math_ops_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) + def benchmark_tfe_py_fastpath_execute_matmul_100_by_784_CPU(self): + with context.device(CPU): + m = self._m_100_by_784.cpu() + self._benchmark_tfe_py_fastpath_execute_matmul( + m, transpose_b=True, num_iters=self._num_iters_100_by_784) + def benchmark_tfe_py_execute_matmul_100_by_784_CPU(self): with context.device(CPU): m = self._m_100_by_784.cpu() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index cbf588336d75dbc16e73ea227d8ebba639a84f1c..b6c7d823237231a138f6a25bb9d03954b69d58d9 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -49,6 +49,8 @@ _MAXINT32 = 2**31 - 1 DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT +DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( + pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) # TODO(agarwal): better name ? @@ -122,6 +124,8 @@ class Context(object): right device but raises a warning. tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide performance problems. + tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, + raising errors on the other ones. """ self._eager_context = _EagerContext() self._context_handle = None @@ -411,6 +415,20 @@ class Context(object): self._initialize_handle_and_devices() pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle) + @tf_contextlib.contextmanager + def device_policy(self, policy): + if not self._context_handle: + self._initialize_handle_and_devices() + old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( + self._context_handle) + pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( + self._handle, policy) + try: + yield + finally: + pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( + self._handle, old) + def disable_run_metadata(self): """Disables tracing of op execution via RunMetadata.""" if not self._context_handle: diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index a70fa7280485497c4795bd890c1a19d2aa52d895..ee3c10633e1cb849e319f2f5490e5beb5dd15c80 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import nn_ops def execute(op_name, num_outputs, inputs, attrs=None): @@ -112,6 +113,14 @@ class TFETest(test_util.TensorFlowTestCase): # is enabled; the stack entry should reflect this fact. self.assertFalse(stack_entry.is_building_function) + def testInt32GPU(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found') + with ops.device('gpu:0'): + xent = nn_ops.sparse_softmax_cross_entropy_with_logits( + logits=[[0.0, 0.0]], labels=[0]) + self.assertAllClose(xent, [0.69314718]) + def _runInThread(self, target, args): t = threading.Thread(target=target, args=args) try: @@ -173,6 +182,15 @@ class TFETest(test_util.TensorFlowTestCase): with self.assertRaises(RuntimeError): x.gpu(context.context().num_gpus() + 1) + def testCopyScope(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found') + constant = constant_op.constant(1.0) + with ops.device('gpu:0'): + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + c = constant + 1.0 + self.assertAllEqual(c, 2.0) + def testNumpyForceCPU(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 9b08a35ff1aaaab2559b89d4c8106685783503d5..0babc29f17b21ee663cdd5bd170875247353e70b 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -400,10 +400,11 @@ class FunctionTest(test.TestCase): # The Reshape op requires the shape tensor to be placed in host memory. reshape = function.defun(array_ops.reshape) - value = constant_op.constant([1., 2.]).gpu() + value = constant_op.constant([1., 2.]) shape = constant_op.constant([2, 1]).gpu() with self.assertRaises(errors.InvalidArgumentError): - reshape(value, shape) + with ops.device('gpu:0'): + reshape(value, shape) def testDifferentiableFunctionNoneOutputs(self): diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index f8c5037dcf8d4c9c2ca90c641981c9280b946c4f..f2e70341d975fb06bce7f2ce6cba7d8c3bc9826c 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -24,7 +24,6 @@ from tensorflow.python.eager import execute from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -246,15 +245,6 @@ class OpsTest(test_util.TensorFlowTestCase): reshaped = array_ops.reshape(value, shape) self.assertAllEqual([[1], [2]], reshaped.cpu()) - # And if the shape is in device memory, it should complain - # TODO(ashankar): Revisit this - perhaps instead of complaining, - # it should implicitly copy the tensor to host memory? - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - 'cannot compute Reshape as input #1 was expected to be on.*' - 'using.*DEVICE_PLACEMENT_SILENT'): - reshaped = array_ops.reshape(value, shape.gpu()) - def testInt64(self): # Fill requires the first input to be an int32 tensor. self.assertAllEqual( diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h index f9dfdf0408f2ea0cf72631e67266ec445b98a868..d27b00139d129aba1c511a21afce749eae8b32ed 100644 --- a/tensorflow/python/eager/python_eager_op_gen.h +++ b/tensorflow/python/eager/python_eager_op_gen.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_ +#ifndef TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_ +#define TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_ #include #include @@ -40,4 +40,4 @@ string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_ +#endif // TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_ diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index cecef426032f967afd122b1cfeec6f29d2d7e7a5..4aea134fa9df845fe2a84f32d56a17a8766bde9b 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -131,6 +131,28 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, PyObject* target, PyObject* sources, PyObject* output_gradients, TF_Status* status); +// Execute a tensorflow operation assuming that all provided inputs are +// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors, +// it will simply fail with a NotImplementedError. +// +// The first PyObject* is unused. +// The "args" PyObject* is meant to be a tuple with the following structure: +// Item 1: The TFE Context +// Item 2: device_name: Name of the device on which to execute the operation, +// or NULL for automatic selection. +// Item 3: op_name: Name of the TensorFlow op to execute. +// Item 4: record_gradient_callback: Callback that records the gradient of the +// result. +// The callback takes (inputs, attrs, result) - all sequences and +// records the gradient. +// Item 5 onwards: inputs - This is a list of inputs followed by a list of +// attrs. It is not necessary for type attrs to be present. +// +// This is named _C since there doesn't seem to be any way to make it visible +// in the SWIG interface without renaming due to the use of the %native +// directive. +PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args); + // Returns the set of variables watched by the given tape. PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 38c3cb21743eb3657274e8e6ce5ebc3fc85e26b9..6162644036998bfaa97ac4a37680b661d844ff7a 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -21,12 +21,16 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/tape.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/compactptrset.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/eager/pywrap_tensor.h" using tensorflow::string; +using tensorflow::strings::Printf; namespace { @@ -289,14 +293,12 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, return true; } -void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, +// start_index is the index at which the Tuple/List attrs will start getting +// processed. +void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index, TF_Status* out_status) { if (attrs == Py_None) return; - if (!PyTuple_Check(attrs)) { - TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Expecting an attrs tuple."); - return; - } - Py_ssize_t len = PyTuple_GET_SIZE(attrs); + Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index; if ((len & 1) != 0) { TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Expecting attrs tuple to have even length."); @@ -304,8 +306,8 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, } // Parse attrs for (Py_ssize_t i = 0; i < len; i += 2) { - PyObject* py_key = PyTuple_GET_ITEM(attrs, i); - PyObject* py_value = PyTuple_GET_ITEM(attrs, i + 1); + PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i); + PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1); #if PY_MAJOR_VERSION >= 3 const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key) : PyUnicode_AsUTF8(py_key); @@ -329,7 +331,6 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr; static tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); static tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0; - } // namespace void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, @@ -346,7 +347,7 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, } } if (TF_GetCode(out_status) == TF_OK) { - SetOpAttrs(ctx, op, attrs, out_status); + SetOpAttrs(ctx, op, attrs, 0, out_status); } Py_BEGIN_ALLOW_THREADS; if (TF_GetCode(out_status) == TF_OK) { @@ -542,10 +543,10 @@ static PyTypeObject TFE_Py_Tape_Type = { // GIL, which is always held when any TFE_Py_* methods are called. We should // revisit this if/when decide to not hold the GIL while manipulating the tape // stack. -static std::unordered_set* tape_set = nullptr; -std::unordered_set* GetTapeSet() { +static tensorflow::gtl::CompactPointerSet* tape_set = nullptr; +tensorflow::gtl::CompactPointerSet* GetTapeSet() { if (tape_set == nullptr) { - tape_set = new std::unordered_set; + tape_set = new tensorflow::gtl::CompactPointerSet; } return tape_set; } @@ -636,8 +637,8 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { if (*ThreadTapeIsStopped()) { Py_RETURN_FALSE; } - auto* tape_set = GetTapeSet(); - if (tape_set->empty()) { + auto* tape_set_ptr = GetTapeSet(); + if (tape_set_ptr->empty()) { Py_RETURN_FALSE; } PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); @@ -654,7 +655,8 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { tensor_ids.push_back(FastTensorId(item)); } Py_DECREF(seq); - for (TFE_Py_Tape* tape : *tape_set) { + auto tape_set = *tape_set_ptr; + for (TFE_Py_Tape* tape : tape_set) { if (tape->tape->ShouldRecord(tensor_ids)) { Py_RETURN_TRUE; } @@ -760,8 +762,7 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors, PyObject* backward_function) { - auto* set = GetTapeSet(); - if (set->empty() || *ThreadTapeIsStopped()) { + if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) { return; } std::vector input_ids = MakeTensorIDList(input_tensors); @@ -796,7 +797,8 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, return; } - for (TFE_Py_Tape* tape : *set) { + auto set = *GetTapeSet(); + for (TFE_Py_Tape* tape : set) { Py_INCREF(backward_function); tape->tape->RecordOperation( op_type_str, output_info, input_ids, backward_function, @@ -805,7 +807,10 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, } void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { - for (TFE_Py_Tape* tape : *GetTapeSet()) { + // Note: making a copy because deleting the trace can trigger a change to the + // set of tapes by allowing python's garbage collector to run. + auto tape_set = *GetTapeSet(); + for (TFE_Py_Tape* tape : tape_set) { tape->tape->DeleteTrace(tensor_id); } } @@ -974,7 +979,6 @@ std::vector MakeTensorList(PyObject* tensors) { return list; } - PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, PyObject* target, PyObject* sources, PyObject* output_gradients, TF_Status* status) { @@ -1029,3 +1033,195 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, Py_INCREF(Py_None); return Py_None; } + +namespace { +static const int kFastPathExecuteInputStartIndex = 4; + +bool CheckEagerTensors(PyObject* seq, int start_index, int num_to_check) { + for (int i = start_index; i < start_index + num_to_check; i++) { + PyObject* item = PyTuple_GET_ITEM(seq, i); + if (!EagerTensor_CheckExact(item)) return false; + } + + return true; +} + +const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) { + const char* op_name = TFE_GetPythonString(py_op_name); + if (op_name == nullptr) { + PyErr_SetString(PyExc_TypeError, + Printf("expected a string for op_name, got %s instead", + py_op_name->ob_type->tp_name) + .c_str()); + return nullptr; + } + + const tensorflow::OpRegistrationData* op_reg_data = nullptr; + const tensorflow::Status lookup_status = + tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) { + return nullptr; + } + return &op_reg_data->op_def; +} + +const char* GetDeviceName(PyObject* py_device_name) { + if (py_device_name != Py_None) { + return TFE_GetPythonString(py_device_name); + } + return nullptr; +} + +bool MaybeRunRecordGradientCallback(const tensorflow::OpDef* op_def, + PyObject* args, PyObject* result, + PyObject* record_gradient_callback) { + if (*ThreadTapeIsStopped() || GetTapeSet()->empty() || + record_gradient_callback == Py_None) { + return true; + } + if (!PyCallable_Check(record_gradient_callback)) { + PyErr_SetString( + PyExc_TypeError, + Printf( + "expected a function for record_gradient_callback, got %s instead", + record_gradient_callback->ob_type->tp_name) + .c_str()); + return false; + } + + PyObject* inputs = PyTuple_New(op_def->input_arg_size()); + for (int i = 0; i < op_def->input_arg_size(); i++) { + auto* input = PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i); + Py_INCREF(input); + PyTuple_SET_ITEM(inputs, i, input); + } + + int args_size = PyTuple_GET_SIZE(args); + int num_attrs = + args_size - op_def->input_arg_size() - kFastPathExecuteInputStartIndex; + PyObject* attrs = PyTuple_New(num_attrs); + for (int i = 0; i < num_attrs; i++) { + auto* attr = PyTuple_GET_ITEM( + args, kFastPathExecuteInputStartIndex + op_def->input_arg_size() + i); + Py_INCREF(attr); + PyTuple_SET_ITEM(attrs, i, attr); + } + + PyObject* callback_args = Py_BuildValue("OOO", inputs, attrs, result); + PyObject_CallObject(record_gradient_callback, callback_args); + + Py_DECREF(inputs); + Py_DECREF(callback_args); + Py_DECREF(attrs); + return true; +} +} // namespace + +PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { + TFE_Context* ctx = reinterpret_cast( + PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr)); + const tensorflow::OpDef* op_def = GetOpDef(PyTuple_GET_ITEM(args, 2)); + if (op_def == nullptr) return nullptr; + const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); + PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3); + + Py_ssize_t args_size = PyTuple_GET_SIZE(args); + if (args_size < kFastPathExecuteInputStartIndex) { + PyErr_SetString( + PyExc_ValueError, + Printf("There must be at least %d items in the input tuple.", + kFastPathExecuteInputStartIndex) + .c_str()); + return nullptr; + } + + if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) { + PyErr_SetString( + PyExc_ValueError, + Printf("Tuple size smaller than intended. Expected to be at least %d, " + "was %ld", + kFastPathExecuteInputStartIndex + op_def->input_arg_size(), + args_size) + .c_str()); + return nullptr; + } + + if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex, + op_def->input_arg_size())) { + // TODO(nareshmodi): Maybe some other way of signalling that this should + // fall back? + PyErr_SetString(PyExc_NotImplementedError, + "This function does not handle the case of the path where " + "all inputs are not already EagerTensors."); + return nullptr; + } + + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx, op_def->name().c_str(), status); + auto cleaner = tensorflow::gtl::MakeCleanup([status, op] { + TF_DeleteStatus(status); + TFE_DeleteOp(op); + }); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return nullptr; + } + + TFE_OpSetDevice(op, device_name, status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return nullptr; + } + + // Add non-type attrs. + SetOpAttrs(ctx, op, args, + kFastPathExecuteInputStartIndex + op_def->input_arg_size(), + status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return nullptr; + } + + // Add type attrs and inputs. + for (int i = 0; i < op_def->input_arg_size(); i++) { + const auto& input_arg = op_def->input_arg(i); + + PyObject* input = + PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i); + TFE_TensorHandle* input_handle = EagerTensor_Handle(input); + + // The following code might set duplicate type attrs. This will result in + // the CacheKey for the generated AttrBuilder possibly differing from those + // where the type attrs are correctly set. Inconsistent CacheKeys for ops + // means that there might be unnecessarily duplicated kernels. + // TODO(nareshmodi): Fix this. + if (!input_arg.type_attr().empty()) { + TFE_OpSetAttrType(op, input_arg.type_attr().data(), + TFE_TensorHandleDataType(input_handle)); + } + + TFE_OpAddInput(op, input_handle, status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return nullptr; + } + } + + int num_retvals = op_def->output_arg_size(); + tensorflow::gtl::InlinedVector retvals(num_retvals); + + Py_BEGIN_ALLOW_THREADS; + TFE_Execute(op, retvals.data(), &num_retvals, status); + Py_END_ALLOW_THREADS; + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return nullptr; + } + + PyObject* result = PyTuple_New(num_retvals); + for (int i = 0; i < num_retvals; ++i) { + PyTuple_SET_ITEM(result, i, EagerTensorFromHandle(retvals[i])); + } + + if (!MaybeRunRecordGradientCallback(op_def, args, result, + record_gradient_callback)) { + return nullptr; + } + + return result; +} diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f4ed592fb99e475af4652a33e5364d9abeea1a --- /dev/null +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -0,0 +1,109 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for low-level eager execution primitives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + + +def record_gradient_callback(inputs, attrs, results): + return backprop._record_gradient("MatMul", inputs, attrs, results, None) + + +def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False): + ctx = context.context() + assert not ctx.in_graph_mode( + ), "The prototype doesn't contain C code for graph construction" + ctx_handle = ctx._handle # pylint: disable=protected-access + + return pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx_handle, ctx.device_name, "MatMul", record_gradient_callback, a, b, + "transpose_a", transpose_a, "transpose_b", transpose_b)[0] + + +class Tests(test.TestCase): + + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_MatMulCorrectResponse(self): + a_2_by_2 = random_ops.random_uniform((2, 2)) + b_2_by_2 = random_ops.random_uniform((2, 2)) + + a_100_by_784 = random_ops.random_uniform((100, 784)) + b_100_by_784 = random_ops.random_uniform((100, 784)) + + self.assertAllClose( + math_ops.matmul(a_2_by_2, b_2_by_2), + c_tfe_py_fastpath_execute(a_2_by_2, b_2_by_2)) + self.assertAllClose( + math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True), + c_tfe_py_fastpath_execute(a_100_by_784, b_100_by_784, transpose_b=True)) + + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_TapeWrite(self): + with backprop.GradientTape(persistent=True) as tape: + a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) + tape.watch(a_2_by_2) + z = c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2) + dz_dy = tape.gradient(z, [a_2_by_2])[0] + self.assertAllEqual(dz_dy.numpy(), + constant_op.constant(4.0, shape=[2, 2]).numpy()) + + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_MatMulSlowPath(self): + a_2_by_2 = random_ops.random_uniform((2, 2)).cpu().numpy() + + with self.assertRaises(NotImplementedError): + c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2) + + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_InvalidInputs(self): + a_2_by_2 = random_ops.random_uniform((2, 2)) + ctx = context.context() + assert not ctx.in_graph_mode( + ), "The prototype doesn't contain C code for graph construction" + ctx_handle = ctx._handle # pylint: disable=protected-access + + with self.assertRaisesRegexp(ValueError, + "at least 4 items in the input tuple"): + pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, + "Identity") + + with self.assertRaisesRegexp(ValueError, + "Expected to be at least 5, was 4"): + pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx_handle, ctx_handle, "Identity", record_gradient_callback) + + with self.assertRaisesRegexp(TypeError, "expected a string for op_name"): + pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx_handle, ctx.device_name, ctx_handle, record_gradient_callback, + a_2_by_2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 2568d3dc0543f925a90f53d77cff724e7effa535..0bd5a5dbafd5ea8da21d4fb8a7dcae9fe23dd3d2 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -112,6 +112,19 @@ class TFETensorTest(test_util.TensorFlowTestCase): numpy_tensor = np.asarray(tensor, dtype=np.int32) self.assertAllEqual(numpy_tensor, [1, 2, 3]) + def testNdimsAgreesWithNumpy(self): + numpy_tensor = np.asarray(1.0) + tensor = constant_op.constant(numpy_tensor) + self.assertAllEqual(numpy_tensor.ndim, tensor.ndim) + + numpy_tensor = np.asarray([1.0, 2.0, 3.0]) + tensor = constant_op.constant(numpy_tensor) + self.assertAllEqual(numpy_tensor.ndim, tensor.ndim) + + numpy_tensor = np.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + tensor = constant_op.constant(numpy_tensor) + self.assertAllEqual(numpy_tensor.ndim, tensor.ndim) + def testCopy(self): t = constant_op.constant(1.0) tt = copy.copy(t) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 63436157371148bc858344a57bf4e180d8a34526..41f55b12af893e3207ad1ffa45098d12b1c4fff6 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -267,6 +267,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) @@ -356,6 +357,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) @@ -624,8 +626,9 @@ py_library( py_test( name = "head_test", - size = "small", + size = "medium", srcs = ["canned/head_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -679,6 +682,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 0392ff9a71920cb966a3731e03b7fc74030292c6..0f274a23c03426fc431c15ac0a14617a4a65bb79 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -22,7 +22,6 @@ import six from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn -from tensorflow.python.estimator import warm_starting_util from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import optimizers from tensorflow.python.feature_column import feature_column as feature_column_lib @@ -31,6 +30,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import training_util @@ -280,6 +280,7 @@ class DNNClassifier(estimator.Estimator): input_layer_partitioner=None, config=None, warm_start_from=None, + loss_reduction=losses.Reduction.SUM, ): """Initializes a `DNNClassifier` instance. @@ -323,19 +324,23 @@ class DNNClassifier(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. """ if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): - """Call the defined shared _dnn_model_fn and possibly warm-start.""" - estimator_spec = _dnn_model_fn( + """Call the defined shared _dnn_model_fn.""" + return _dnn_model_fn( features=features, labels=labels, mode=mode, @@ -347,17 +352,10 @@ class DNNClassifier(estimator.Estimator): dropout=dropout, input_layer_partitioner=input_layer_partitioner, config=config) - # pylint: disable=protected-access - warm_start_settings = warm_starting_util._get_default_warm_start_settings( - warm_start_from) - if warm_start_settings: - warm_starting_util._warm_start(warm_start_settings) - # pylint: enable=protected-access - - return estimator_spec super(DNNClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) class DNNRegressor(estimator.Estimator): @@ -441,6 +439,7 @@ class DNNRegressor(estimator.Estimator): input_layer_partitioner=None, config=None, warm_start_from=None, + loss_reduction=losses.Reduction.SUM, ): """Initializes a `DNNRegressor` instance. @@ -478,17 +477,20 @@ class DNNRegressor(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. """ def _model_fn(features, labels, mode, config): - """Call the defined shared _dnn_model_fn and possibly warm-start.""" - estimator_spec = _dnn_model_fn( + """Call the defined shared _dnn_model_fn.""" + return _dnn_model_fn( features=features, labels=labels, mode=mode, head=head_lib. # pylint: disable=protected-access _regression_head_with_mean_squared_error_loss( - label_dimension=label_dimension, weight_column=weight_column), + label_dimension=label_dimension, weight_column=weight_column, + loss_reduction=loss_reduction), hidden_units=hidden_units, feature_columns=tuple(feature_columns or []), optimizer=optimizer, @@ -496,14 +498,7 @@ class DNNRegressor(estimator.Estimator): dropout=dropout, input_layer_partitioner=input_layer_partitioner, config=config) - # pylint: disable=protected-access - warm_start_settings = warm_starting_util._get_default_warm_start_settings( - warm_start_from) - if warm_start_settings: - warm_starting_util._warm_start(warm_start_settings) - # pylint: enable=protected-access - - return estimator_spec super(DNNRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 1d06a54a321233722ba0736f7d658cd6029991f8..1a0f4c5c3931a6b41026470f30e7bdd381e5b37a 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -23,7 +23,6 @@ import math import six from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import warm_starting_util from tensorflow.python.estimator.canned import dnn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import linear @@ -34,6 +33,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util @@ -309,7 +309,8 @@ class DNNLinearCombinedClassifier(estimator.Estimator): label_vocabulary=None, input_layer_partitioner=None, config=None, - warm_start_from=None): + warm_start_from=None, + loss_reduction=losses.Reduction.SUM): """Initializes a DNNLinearCombinedClassifier instance. Args: @@ -356,6 +357,8 @@ class DNNLinearCombinedClassifier(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -371,16 +374,18 @@ class DNNLinearCombinedClassifier(estimator.Estimator): if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): - """Call the _dnn_linear_combined_model_fn and possibly warm-start.""" - estimator_spec = _dnn_linear_combined_model_fn( + """Call the _dnn_linear_combined_model_fn.""" + return _dnn_linear_combined_model_fn( features=features, labels=labels, mode=mode, @@ -394,17 +399,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, config=config) - # pylint: disable=protected-access - warm_start_settings = warm_starting_util._get_default_warm_start_settings( - warm_start_from) - if warm_start_settings: - warm_starting_util._warm_start(warm_start_settings) - # pylint: enable=protected-access - - return estimator_spec super(DNNLinearCombinedClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) class DNNLinearCombinedRegressor(estimator.Estimator): @@ -490,7 +488,8 @@ class DNNLinearCombinedRegressor(estimator.Estimator): weight_column=None, input_layer_partitioner=None, config=None, - warm_start_from=None): + warm_start_from=None, + loss_reduction=losses.Reduction.SUM): """Initializes a DNNLinearCombinedRegressor instance. Args: @@ -531,6 +530,8 @@ class DNNLinearCombinedRegressor(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -545,14 +546,15 @@ class DNNLinearCombinedRegressor(estimator.Estimator): 'must be defined.') def _model_fn(features, labels, mode, config): - """Call the _dnn_linear_combined_model_fn and possibly warm-start.""" - estimator_spec = _dnn_linear_combined_model_fn( + """Call the _dnn_linear_combined_model_fn.""" + return _dnn_linear_combined_model_fn( features=features, labels=labels, mode=mode, head=head_lib. # pylint: disable=protected-access _regression_head_with_mean_squared_error_loss( - label_dimension=label_dimension, weight_column=weight_column), + label_dimension=label_dimension, weight_column=weight_column, + loss_reduction=loss_reduction), linear_feature_columns=linear_feature_columns, linear_optimizer=linear_optimizer, dnn_feature_columns=dnn_feature_columns, @@ -562,14 +564,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator): dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, config=config) - # pylint: disable=protected-access - warm_start_settings = warm_starting_util._get_default_warm_start_settings( - warm_start_from) - if warm_start_settings: - warm_starting_util._warm_start(warm_start_settings) - # pylint: enable=protected-access - - return estimator_spec super(DNNLinearCombinedRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 2bdec693033858fd3bbbb137259b2d129fc72797..706575985ff9e0fef94f110825ec11af33031ea3 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -877,7 +877,7 @@ class BaseDNNWarmStartingTest(object): # Create a second DNNClassifier, warm-started from the first. Use a # learning_rate = 0.0 optimizer to check values (use SGD so we don't have - # accumulator values that change). Use a a new FeatureColumn with a + # accumulator values that change). Use a new FeatureColumn with a # different vocabulary for occupation. new_vocab_list = ['doctor', 'consultant', 'engineer'] new_vocab_file = os.path.join(self._ckpt_and_vocab_dir, diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 204e1119f2191457359ecaf9fd012fcb5a2b0463..94a5d3a342dd7bad49d5fb4b91166c67a2705ff3 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -627,15 +627,15 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`. For many applications, the shape is `[batch_size, logits_dimension]`. labels: Labels integer or string `Tensor` with shape matching `logits`, - namely `[D0, D1, ... DN, 1]`. `labels` is required argument when `mode` - equals `TRAIN` or `EVAL`. + namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is + required argument when `mode` equals `TRAIN` or `EVAL`. train_op_fn: Function that takes a scalar loss `Tensor` and returns `train_op`. Required in TRAIN mode. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to - set `loss_reduction=MEAN_PER_ELEMENT` or - `loss_reduction=SUM_BY_NONZERO_WEIGHTS` when creating the head to + set `loss_reduction=SUM_OVER_BATCH_SIZE` or + `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to avoid scaling errors. Returns: `EstimatorSpec`. @@ -827,10 +827,10 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): return 1 def _eval_metric_ops(self, labels, logits, logistic, class_ids, weights, - unreduced_loss): + unreduced_loss, regularization_loss): with ops.name_scope(None, 'metrics', (labels, logits, logistic, class_ids, weights, - unreduced_loss)): + unreduced_loss, regularization_loss)): keys = metric_keys.MetricKeys labels_mean = _indicator_labels_mean( labels=labels, weights=weights, name=keys.LABEL_MEAN) @@ -870,6 +870,11 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): curve='PR', name=keys.AUC_PR) } + if regularization_loss is not None: + metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = ( + metrics_lib.mean( + values=regularization_loss, + name=keys.LOSS_REGULARIZATION)) for threshold in self._thresholds: accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold metric_ops[_summary_key(self._name, @@ -924,8 +929,31 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): processed_labels=labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): - """See `Head`.""" + self, features, mode, logits, labels=None, train_op_fn=None, + regularization_losses=None): + """Returns an `EstimatorSpec`. + + Args: + features: Input `dict` of `Tensor` or `SparseTensor` objects. + mode: Estimator's `ModeKeys`. + logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many + applications, the shape is `[batch_size, 1]`. + labels: Labels integer or string `Tensor` with shape matching `logits`, + namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required + argument when `mode` equals `TRAIN` or `EVAL`. + train_op_fn: Function that takes a scalar loss `Tensor` and returns + `train_op`. Required in TRAIN mode. + regularization_losses: A list of additional scalar losses to be added to + the training loss, such as regularization losses. These losses are + usually expressed as a batch average, so for best results users need to + set `loss_reduction=SUM_OVER_BATCH_SIZE` or + `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to + avoid scaling errors. + Returns: + `EstimatorSpec`. + Raises: + ValueError: If `train_op_fn` is `None` in TRAIN mode. + """ # Predict. with ops.name_scope(self._name, 'head'): with ops.name_scope(None, 'predictions', (logits,)): @@ -972,20 +1000,28 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): (training_loss, unreduced_loss, weights, processed_labels) = ( self.create_loss( features=features, mode=mode, logits=logits, labels=labels)) + if regularization_losses: + regularization_loss = math_ops.add_n(regularization_losses) + regularized_training_loss = math_ops.add_n( + [training_loss, regularization_loss]) + else: + regularization_loss = None + regularized_training_loss = training_loss # Eval. if mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, - loss=training_loss, + loss=regularized_training_loss, eval_metric_ops=self._eval_metric_ops( labels=processed_labels, logits=logits, logistic=logistic, class_ids=class_ids, weights=weights, - unreduced_loss=unreduced_loss)) + unreduced_loss=unreduced_loss, + regularization_loss=regularization_loss)) # Train. if train_op_fn is None: @@ -999,18 +1035,22 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): else: mean_loss = None with ops.name_scope(''): + keys = metric_keys.MetricKeys summary.scalar( - _summary_key(self._name, metric_keys.MetricKeys.LOSS), - training_loss) + _summary_key(self._name, keys.LOSS), + regularized_training_loss) if mean_loss is not None: summary.scalar( - _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN), - mean_loss) + _summary_key(self._name, keys.LOSS_MEAN), mean_loss) + if regularization_loss is not None: + summary.scalar( + _summary_key(self._name, keys.LOSS_REGULARIZATION), + regularization_loss) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, - loss=training_loss, - train_op=train_op_fn(training_loss)) + loss=regularized_training_loss, + train_op=train_op_fn(regularized_training_loss)) def _regression_head_with_mean_squared_error_loss( @@ -1111,7 +1151,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): processed_labels=labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): + self, features, mode, logits, labels=None, train_op_fn=None, + regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -1125,6 +1166,12 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): `mode` equals `TRAIN` or `EVAL`. train_op_fn: Function that takes a scalar loss `Tensor` and returns `train_op`. Required in TRAIN mode. + regularization_losses: A list of additional scalar losses to be added to + the training loss, such as regularization losses. These losses are + usually expressed as a batch average, so for best results users need to + set `loss_reduction=SUM_OVER_BATCH_SIZE` or + `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to + avoid scaling errors. Returns: `EstimatorSpec`. Raises: @@ -1147,20 +1194,34 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): training_loss, unreduced_loss, weights, _ = self.create_loss( features=features, mode=mode, logits=logits, labels=labels) + if regularization_losses: + regularization_loss = math_ops.add_n(regularization_losses) + regularized_training_loss = math_ops.add_n( + [training_loss, regularization_loss]) + else: + regularization_loss = None + regularized_training_loss = training_loss # Eval. if mode == model_fn.ModeKeys.EVAL: + keys = metric_keys.MetricKeys # Estimator already adds a metric for loss. eval_metric_ops = { - _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN): + _summary_key(self._name, keys.LOSS_MEAN): metrics_lib.mean( values=unreduced_loss, weights=weights) } + if regularization_loss is not None: + regularization_loss_key = _summary_key( + self._name, keys.LOSS_REGULARIZATION) + eval_metric_ops[regularization_loss_key] = metrics_lib.mean( + values=regularization_loss, + name=keys.LOSS_REGULARIZATION) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, - loss=training_loss, + loss=regularized_training_loss, eval_metric_ops=eval_metric_ops) # Train. @@ -1175,18 +1236,22 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): else: mean_loss = None with ops.name_scope(''): + keys = metric_keys.MetricKeys summary.scalar( - _summary_key(self._name, metric_keys.MetricKeys.LOSS), - training_loss) + _summary_key(self._name, keys.LOSS), + regularized_training_loss) if mean_loss is not None: summary.scalar( - _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN), - mean_loss) + _summary_key(self._name, keys.LOSS_MEAN), mean_loss) + if regularization_loss is not None: + summary.scalar( + _summary_key(self._name, keys.LOSS_REGULARIZATION), + regularization_loss) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, - loss=training_loss, - train_op=train_op_fn(training_loss)) + loss=regularized_training_loss, + train_op=train_op_fn(regularized_training_loss)) def _assert_range(labels, n_classes, message=None): diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 28b8e635fb483252edf68a4140bb57c3b99fb96a..4e871e8f375f346bfd1b0be2cade97c34871f31c 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -487,7 +487,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): def test_eval_with_regularization_losses(self): n_classes = 3 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( - n_classes, loss_reduction=losses.Reduction.MEAN) + n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32) labels = np.array(((1,), (1,)), dtype=np.int64) features = {'x': np.array(((42,),), dtype=np.int32)} @@ -790,7 +790,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): def test_train_with_regularization_losses(self): n_classes = 3 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( - n_classes, loss_reduction=losses.Reduction.MEAN) + n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32) labels = np.array(((1,), (1,)), dtype=np.int64) @@ -1485,6 +1485,53 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): ] self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys()) + def test_eval_with_regularization_losses(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + logits = np.array(((45,), (-41,),), dtype=np.float32) + labels = np.array(((1,), (1,),), dtype=np.int32) + features = {'x': np.array(((42,),), dtype=np.int32)} + regularization_losses = [1.5, 0.5] + expected_regularization_loss = 2. + # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size + # = sum(0, 41) / 2 = 20.5 + expected_unregularized_loss = 20.5 + expected_regularized_loss = ( + expected_unregularized_loss + expected_regularization_loss) + + # Create estimator spec. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels, + regularization_losses=regularization_losses) + + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: expected_unregularized_loss, + keys.LOSS_REGULARIZATION: expected_regularization_loss, + keys.ACCURACY: 1./2, + keys.PREDICTION_MEAN: 1./2, + keys.LABEL_MEAN: 2./2, + keys.ACCURACY_BASELINE: 2./2, + keys.AUC: 0., + keys.AUC_PR: 1., + } + + # Assert predictions, loss, and metrics. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} + update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} + loss, metrics = sess.run((spec.loss, update_ops)) + self.assertAllClose(expected_regularized_loss, loss) + # Check results of both update (in `metrics`) and value ops. + self.assertAllClose(expected_metrics, metrics) + self.assertAllClose( + expected_metrics, {k: value_ops[k].eval() for k in value_ops}) + def test_eval_with_vocabulary_list_create_loss(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( label_vocabulary=['aang', 'iroh']) @@ -1749,6 +1796,49 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): }, summary_str) + def test_train_with_regularization_losses(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + + logits = np.array(((45,), (-41,),), dtype=np.float32) + labels = np.array(((1,), (1,),), dtype=np.float64) + expected_train_result = b'my_train_op' + features = {'x': np.array(((42,),), dtype=np.float32)} + regularization_losses = [1.5, 0.5] + expected_regularization_loss = 2. + # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size + # = sum(0, 41) / 2 = 20.5 + # loss = unregularized_loss + regularization_loss = 7. + expected_loss = 22.5 + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_equal( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + # Create estimator spec. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn, + regularization_losses=regularization_losses) + + # Assert predictions, loss, train_op, and summaries. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNotNone(spec.scaffold.summary_op) + loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, + spec.scaffold.summary_op)) + self.assertAllClose(expected_loss, loss) + self.assertEqual(expected_train_result, train_result) + _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_REGULARIZATION: ( + expected_regularization_loss), + }, summary_str) + def test_float_labels_train_create_loss(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() @@ -2512,6 +2602,51 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): ] self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys()) + def test_eval_with_regularization_losses(self): + head = head_lib._regression_head_with_mean_squared_error_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + self.assertEqual(1, head.logits_dimension) + + logits = np.array(((45,), (41,),), dtype=np.float32) + labels = np.array(((43,), (44,),), dtype=np.int32) + features = {'x': np.array(((42,),), dtype=np.float32)} + regularization_losses = [1.5, 0.5] + expected_regularization_loss = 2. + # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size + # = (4 + 9) / 2 = 6.5 + expected_unregularized_loss = 6.5 + expected_regularized_loss = ( + expected_unregularized_loss + expected_regularization_loss) + # Create estimator spec. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels, + regularization_losses=regularization_losses) + + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: expected_unregularized_loss, + keys.LOSS_REGULARIZATION: expected_regularization_loss, + } + + # Assert predictions, loss, and metrics. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} + update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} + prediction_key = prediction_keys.PredictionKeys.PREDICTIONS + predictions, loss, metrics = sess.run(( + spec.predictions[prediction_key], spec.loss, update_ops)) + self.assertAllClose(logits, predictions) + self.assertAllClose(expected_regularized_loss, loss) + # Check results of both update (in `metrics`) and value ops. + self.assertAllClose(expected_metrics, metrics) + self.assertAllClose( + expected_metrics, {k: value_ops[k].eval() for k in value_ops}) + def test_train_create_loss(self): head = head_lib._regression_head_with_mean_squared_error_loss() logits = np.array(((45,), (41,),), dtype=np.float32) @@ -2666,6 +2801,53 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): }, summary_str) + def test_train_with_regularization_losses(self): + head = head_lib._regression_head_with_mean_squared_error_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + self.assertEqual(1, head.logits_dimension) + + # Create estimator spec. + logits = np.array(((45,), (41,),), dtype=np.float32) + labels = np.array(((43.,), (44.,),), dtype=np.float64) + expected_train_result = b'my_train_op' + features = {'x': np.array(((42.,),), dtype=np.float32)} + regularization_losses = [1.5, 0.5] + expected_regularization_loss = 2. + # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size + # = (4 + 9) / 2 = 6.5 + # loss = unregularized_loss + regularization_loss = 8.5 + expected_loss = 8.5 + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_equal( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn, + regularization_losses=regularization_losses) + + # Assert predictions, loss, train_op, and summaries. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNotNone(spec.scaffold.summary_op) + prediction_key = prediction_keys.PredictionKeys.PREDICTIONS + predictions, loss, train_result, summary_str = sess.run(( + spec.predictions[prediction_key], spec.loss, spec.train_op, + spec.scaffold.summary_op)) + self.assertAllClose(logits, predictions) + self.assertAllClose(expected_loss, loss) + self.assertEqual(expected_train_result, train_result) + _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_REGULARIZATION: ( + expected_regularization_loss), + }, summary_str) + def test_weighted_multi_example_eval(self): """1d label, 3 examples, 1 batch.""" head = head_lib._regression_head_with_mean_squared_error_loss( diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index 97cfd24a101edbb88bca54fe3e213d126002779b..a5b1172e729240a2ea02fa1d4330420786c2686c 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -23,7 +23,6 @@ import math import six from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import warm_starting_util from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import optimizers from tensorflow.python.feature_column import feature_column as feature_column_lib @@ -31,6 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import ftrl from tensorflow.python.training import training_util @@ -245,7 +245,8 @@ class LinearClassifier(estimator.Estimator): optimizer='Ftrl', config=None, partitioner=None, - warm_start_from=None): + warm_start_from=None, + loss_reduction=losses.Reduction.SUM): """Construct a `LinearClassifier` estimator object. Args: @@ -282,6 +283,8 @@ class LinearClassifier(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights and biases are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. Returns: A `LinearClassifier` estimator. @@ -292,15 +295,17 @@ class LinearClassifier(estimator.Estimator): if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): - """Call the defined shared _linear_model_fn and possibly warm-start.""" - estimator_spec = _linear_model_fn( + """Call the defined shared _linear_model_fn.""" + return _linear_model_fn( features=features, labels=labels, mode=mode, @@ -309,19 +314,12 @@ class LinearClassifier(estimator.Estimator): optimizer=optimizer, partitioner=partitioner, config=config) - # pylint: disable=protected-access - warm_start_settings = warm_starting_util._get_default_warm_start_settings( - warm_start_from) - if warm_start_settings: - warm_starting_util._warm_start(warm_start_settings) - # pylint: enable=protected-access - - return estimator_spec super(LinearClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, - config=config) + config=config, + warm_start_from=warm_start_from) class LinearRegressor(estimator.Estimator): @@ -388,7 +386,8 @@ class LinearRegressor(estimator.Estimator): optimizer='Ftrl', config=None, partitioner=None, - warm_start_from=None): + warm_start_from=None, + loss_reduction=losses.Reduction.SUM): """Initializes a `LinearRegressor` instance. Args: @@ -417,13 +416,16 @@ class LinearRegressor(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights and biases are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. """ head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access - label_dimension=label_dimension, weight_column=weight_column) + label_dimension=label_dimension, weight_column=weight_column, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): - """Call the defined shared _linear_model_fn and possibly warm-start.""" - estimator_spec = _linear_model_fn( + """Call the defined shared _linear_model_fn.""" + return _linear_model_fn( features=features, labels=labels, mode=mode, @@ -432,16 +434,9 @@ class LinearRegressor(estimator.Estimator): optimizer=optimizer, partitioner=partitioner, config=config) - # pylint: disable=protected-access - warm_start_settings = warm_starting_util._get_default_warm_start_settings( - warm_start_from) - if warm_start_settings: - warm_starting_util._warm_start(warm_start_settings) - # pylint: enable=protected-access - - return estimator_spec super(LinearRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, - config=config) + config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index cccb9af4b21daca45b9db5b921cd6a0a726edb7e..3e9183cf1b633757074377472e9b4cac953e04a1 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -2003,7 +2003,7 @@ class BaseLinearWarmStartingTest(object): # Create a second LinearClassifier, warm-started from the first. Use a # learning_rate = 0.0 optimizer to check values (use SGD so we don't have - # accumulator values that change). Use a a new FeatureColumn with a + # accumulator values that change). Use a new FeatureColumn with a # different vocabulary for occupation. new_vocab_list = ['doctor', 'consultant', 'engineer'] new_vocab_file = os.path.join(self._ckpt_and_vocab_dir, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 90eecc1fda5c432a348bbaa4d35c4dc92f2d7489..96555b5e03c7a291480b3c30fe1f2c641c5c75e1 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -35,6 +35,7 @@ from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util +from tensorflow.python.estimator import warm_starting_util from tensorflow.python.estimator.export.export import build_all_signature_defs from tensorflow.python.estimator.export.export import get_temp_export_dir from tensorflow.python.estimator.export.export import get_timestamped_export_dir @@ -54,6 +55,7 @@ from tensorflow.python.training import saver from tensorflow.python.training import training from tensorflow.python.training import training_util from tensorflow.python.util import compat +from tensorflow.python.util import compat_internal from tensorflow.python.util import nest @@ -96,9 +98,22 @@ class Estimator(object): @end_compatibility """ - def __init__(self, model_fn, model_dir=None, config=None, params=None): + def __init__(self, model_fn, model_dir=None, config=None, params=None, + warm_start_from=None): """Constructs an `Estimator` instance. + See @{$estimators} for more information. To warm-start an `Estimator`: + + ```python + estimator = tf.estimator.DNNClassifier( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + warm_start_from="/path/to/checkpoint/dir") + ``` + + For more details on warm-start configuration, see + @{tf.estimator.WarmStartSettings$WarmStartSettings}. + Args: model_fn: Model function. Follows the signature: @@ -128,12 +143,19 @@ class Estimator(object): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. If `None`, the model_dir in - `config` will be used if set. If both are set, they must be same. If - both are `None`, a temporary directory will be used. + continue training a previously saved model. If `PathLike` object, the + path will be resolved. If `None`, the model_dir in `config` will be used + if set. If both are set, they must be same. If both are `None`, a + temporary directory will be used. config: Configuration object. params: `dict` of hyper parameters that will be passed into `model_fn`. Keys are names of parameters, values are basic python types. + warm_start_from: Optional string filepath to a checkpoint to warm-start + from, or a `tf.estimator.WarmStartSettings` object to + fully configure warm-starting. If the string filepath is + provided instead of a `WarmStartSettings`, then all + variables are warm-started, and it is assumed that + vocabularies and Tensor names are unchanged. Raises: RuntimeError: If eager execution is enabled. @@ -158,6 +180,7 @@ class Estimator(object): self._config = config # Model directory. + model_dir = compat_internal.path_to_str(model_dir) if (model_dir is not None) and (self._config.model_dir is not None): if model_dir != self._config.model_dir: # TODO(alanyee): remove this suppression after it is no longer needed @@ -190,6 +213,11 @@ class Estimator(object): self._model_fn = model_fn self._params = copy.deepcopy(params or {}) + # pylint: disable=protected-access + self._warm_start_settings = ( + warm_starting_util._get_default_warm_start_settings(warm_start_from)) + # pylint: enable=protected-access + @property def model_dir(self): return self._model_dir @@ -453,6 +481,7 @@ class Estimator(object): with training.MonitoredSession( session_creator=training.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, + master=self._config.master, scaffold=estimator_spec.scaffold, config=self._session_config), hooks=input_hooks + hooks) as mon_sess: @@ -778,6 +807,13 @@ class Estimator(object): worker_hooks.extend(input_hooks) estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.config) + + if self._warm_start_settings: + logging.info('Warm-starting with WarmStartSettings: %s' % + (self._warm_start_settings,)) + # pylint: disable=protected-access + warm_starting_util._warm_start(self._warm_start_settings) + # pylint: enable=protected-access # Check if the user created a loss summary, and add one if they didn't. # We assume here that the summary is called 'loss'. If it is not, we will # make another one with the name 'loss' to ensure it shows up in the right diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index ed1676a92de19203be8bc61fc6efeb559a2fb8aa..833f3dcac3b97962c967cba9ac7ab53a3b9c61f1 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -52,6 +52,7 @@ from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile @@ -629,6 +630,33 @@ class EstimatorTrainTest(test.TestCase): self.assertEqual( 10, estimator._load_global_step_from_checkpoint_dir(est.model_dir)) + def test_warm_starts(self): + def _make_model_fn(x): + def _variable_creating_model_fn(features, labels, mode): + _, _ = features, labels + variable_scope.get_variable('x', initializer=x) + global_step = training.get_global_step() + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(global_step, 1)) + return _variable_creating_model_fn + + est = estimator.Estimator(model_fn=_make_model_fn(42.)) + est.train(dummy_input_fn, steps=10) + + warm_started_est = estimator.Estimator( + model_fn=_make_model_fn(36.), + warm_start_from=est.model_dir) + warm_started_est.train(dummy_input_fn, steps=5) + # warm_start is called after the model_fn, so x should have the value + # from the checkpoint. + self.assertEqual(42., warm_started_est.get_variable_value('x')) + # global_step should not be warm-started. + self.assertEqual( + 5, estimator._load_global_step_from_checkpoint_dir( + warm_started_est.model_dir)) + def test_max_step(self): est = estimator.Estimator(model_fn=model_fn_global_step_incrementer) est.train(dummy_input_fn, max_steps=5) diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index dc714d4d22ccf6c14c544f84cba99b2bac55da88..e446b3e03a262e0d4abe69df73cb8604b0dab9f9 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -27,6 +27,8 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +from tensorflow.python.util import compat +from tensorflow.python.util import compat_internal _USE_DEFAULT = object() @@ -399,7 +401,8 @@ class RunConfig(object): Args: model_dir: directory where model parameters, graph, etc are saved. If - `None`, will use a default value set by the Estimator. + `PathLike` object, the path will be resolved. If `None`, will use a + default value set by the Estimator. tf_random_seed: Random seed for TensorFlow initializers. Setting this value allows consistency between reruns. save_summary_steps: Save summaries every this many steps. @@ -442,7 +445,8 @@ class RunConfig(object): if tf_config: logging.info('TF_CONFIG environment variable: %s', tf_config) - model_dir = _get_model_dir(tf_config, model_dir) + model_dir = _get_model_dir(tf_config, + compat_internal.path_to_str(model_dir)) RunConfig._replace( self, @@ -484,7 +488,7 @@ class RunConfig(object): self._num_ps_replicas = _count_ps(self._cluster_spec) self._num_worker_replicas = _count_worker( self._cluster_spec, chief_task_type=TaskType.CHIEF) - self._global_id = _get_global_id_in_cluster( + self._global_id_in_cluster = _get_global_id_in_cluster( self._cluster_spec, self._task_type, self._task_id, @@ -495,14 +499,14 @@ class RunConfig(object): self._master = _LOCAL_MASTER self._num_ps_replicas = 0 self._num_worker_replicas = 0 - self._global_id = None # undefined + self._global_id_in_cluster = None # undefined self._is_chief = self._task_type == TaskType.CHIEF else: # Local mode. self._task_type = task_env.get(_TASK_TYPE_KEY, TaskType.WORKER) self._task_id = int(task_env.get(_TASK_ID_KEY, 0)) - self._global_id = 0 + self._global_id_in_cluster = 0 if self._task_type != TaskType.WORKER: raise ValueError( @@ -537,7 +541,7 @@ class RunConfig(object): raise ValueError('If `master` node exists in `cluster`, task_type ' '`evaluator` is not supported.') - self._global_id = _get_global_id_in_cluster( + self._global_id_in_cluster = _get_global_id_in_cluster( self._cluster_spec, self._task_type, self._task_id, @@ -619,7 +623,7 @@ class RunConfig(object): Returns: An integer id. """ - return self._global_id + return self._global_id_in_cluster @property def task_type(self): diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index c748b318b730f4a4ff855c5e4558da88ada9581b..ad95c71234f82457cb938ca55214b28086b033a2 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -402,10 +402,10 @@ def _warm_start_var_with_vocab(var, def _warm_start(warm_start_settings): - """Warmstarts a model using the given settings. + """Warm-starts a model using the given settings. - Currently, this is intended for use only in canned Estimators. Once made - public, it can be used in any model_fn. + If you are using a tf.estimator.Estimator, this will automatically be called + during training. Args: warm_start_settings: An object of `WarmStartSettings`. diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 76d44fc474f936733f4eeeefd5d9510964ebb430..a758f8a4fc4898713772c4e919acda48b0f6ad0b 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -85,6 +85,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:lookup_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:partitioned_variables", @@ -93,6 +94,8 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", "//tensorflow/python/estimator:numpy_io", ], ) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index a7fe528ee1d85c3c06d4e9376ca4937aaf168b8a..7feb209cc49c4be70387c44168dbdeea6d108d66 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -657,11 +657,11 @@ def embedding_column( trainable=trainable) -def _shared_embedding_columns( +def shared_embedding_columns( categorical_columns, dimension, combiner='mean', initializer=None, shared_embedding_collection_name=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, trainable=True): - """List of `_DenseColumn`s that convert from sparse, categorical input. + """List of dense columns that convert from sparse, categorical input. This is similar to `embedding_column`, except that that it produces a list of embedding columns that share the same embedding weights. @@ -670,7 +670,7 @@ def _shared_embedding_columns( impression video IDs that share the same vocabulary), and you want to convert them to a dense representation (e.g., to feed to a DNN). - Inputs must be a list of `_CategoricalColumn` created by any of the + Inputs must be a list of categorical columns created by any of the `categorical_column_*` function. They must all be of the same type and have the same arguments except `key`. E.g. they can be categorical_column_with_vocabulary_file with the same vocabulary_file. Some or @@ -714,7 +714,7 @@ def _shared_embedding_columns( ``` Args: - categorical_columns: List of `_CategoricalColumn`s created by a + categorical_columns: List of categorical columns created by a `categorical_column_with_*` function. These columns produce the sparse IDs that are inputs to the embedding lookup. All columns must be of the same type and have the same arguments except `key`. E.g. they can be @@ -744,7 +744,7 @@ def _shared_embedding_columns( trainable: Whether or not the embedding is trainable. Default is True. Returns: - A list of `_DenseColumn`s that converts from sparse input. The order of + A list of dense columns that converts from sparse input. The order of results follows the ordering of `categorical_columns`. Raises: diff --git a/tensorflow/python/feature_column/feature_column_lib.py b/tensorflow/python/feature_column/feature_column_lib.py index 8a57986764f9f5e2cff788817cc7706089dc73b0..505a1408d271e9262226b2ea4cff234345e2f3b6 100644 --- a/tensorflow/python/feature_column/feature_column_lib.py +++ b/tensorflow/python/feature_column/feature_column_lib.py @@ -29,6 +29,7 @@ _allowed_symbols = [ 'linear_model', 'make_parse_example_spec', 'embedding_column', + 'shared_embedding_columns', 'crossed_column', 'numeric_column', 'bucketized_column', diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 2374680b968813b76d0ec115aa46c547eb9ab036..6f366e77229577b1a6a5363f882daa07203f525c 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -29,7 +29,6 @@ from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column as fc_lib from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column.feature_column import _CategoricalColumn from tensorflow.python.feature_column.feature_column import _DenseColumn @@ -1072,6 +1071,7 @@ def get_linear_model_column_var(column): 'linear_model/' + column.name)[0] +@test_util.with_c_api class LinearModelTest(test.TestCase): def test_raises_if_empty_feature_columns(self): @@ -1325,10 +1325,16 @@ class LinearModelTest(test.TestCase): price = fc.numeric_column('price', shape=2) with ops.Graph().as_default(): features = {'price': [[1.], [5.]]} - predictions = fc.linear_model(features, [price]) - with _initialized_session(): - with self.assertRaisesRegexp(Exception, 'requested shape has 4'): - predictions.eval() + if ops._USE_C_API: + with self.assertRaisesRegexp( + Exception, + r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): + predictions = fc.linear_model(features, [price]) + else: + predictions = fc.linear_model(features, [price]) + with _initialized_session(): + with self.assertRaisesRegexp(Exception, 'requested shape has 4'): + predictions.eval() def test_dense_reshaping(self): price = fc.numeric_column('price', shape=[1, 2]) @@ -1791,6 +1797,7 @@ class InputLayerTest(test.TestCase): self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient) +@test_util.with_c_api class FunctionalInputLayerTest(test.TestCase): def test_raises_if_empty_feature_columns(self): @@ -1855,10 +1862,16 @@ class FunctionalInputLayerTest(test.TestCase): price = fc.numeric_column('price', shape=2) with ops.Graph().as_default(): features = {'price': [[1.], [5.]]} - net = fc.input_layer(features, [price]) - with _initialized_session(): - with self.assertRaisesRegexp(Exception, 'requested shape has 4'): - net.eval() + if ops._USE_C_API: + with self.assertRaisesRegexp( + Exception, + r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): + net = fc.input_layer(features, [price]) + else: + net = fc.input_layer(features, [price]) + with _initialized_session(): + with self.assertRaisesRegexp(Exception, 'requested shape has 4'): + net.eval() def test_reshaping(self): price = fc.numeric_column('price', shape=[1, 2]) @@ -4137,7 +4150,7 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - embedding_column_b, embedding_column_a = fc_lib._shared_embedding_columns( + embedding_column_b, embedding_column_a = fc.shared_embedding_columns( [categorical_column_b, categorical_column_a], dimension=embedding_dimension) self.assertIs(categorical_column_a, embedding_column_a.categorical_column) @@ -4183,7 +4196,7 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, combiner='my_combiner', @@ -4236,7 +4249,7 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) embedding_dimension = 2 - original_a, _ = fc_lib._shared_embedding_columns( + original_a, _ = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, combiner='my_combiner', @@ -4274,7 +4287,7 @@ class SharedEmbeddingColumnTest(test.TestCase): categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=3) with self.assertRaisesRegexp(ValueError, 'initializer must be callable'): - fc_lib._shared_embedding_columns( + fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2, initializer='not_fn') @@ -4289,7 +4302,7 @@ class SharedEmbeddingColumnTest(test.TestCase): ValueError, 'all categorical_columns must have the same type.*' '_IdentityCategoricalColumn.*_HashedCategoricalColumn'): - fc_lib._shared_embedding_columns( + fc.shared_embedding_columns( [categorical_column_a, categorical_column_b, categorical_column_c], dimension=2) @@ -4302,11 +4315,11 @@ class SharedEmbeddingColumnTest(test.TestCase): key='bbb', num_buckets=3) weighted_categorical_column_b = fc.weighted_categorical_column( categorical_column_b, weight_feature_key='bbb_weights') - fc_lib._shared_embedding_columns( + fc.shared_embedding_columns( [weighted_categorical_column_a, categorical_column_b], dimension=2) - fc_lib._shared_embedding_columns( + fc.shared_embedding_columns( [categorical_column_a, weighted_categorical_column_b], dimension=2) - fc_lib._shared_embedding_columns( + fc.shared_embedding_columns( [weighted_categorical_column_a, weighted_categorical_column_b], dimension=2) @@ -4315,7 +4328,7 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) b = fc.categorical_column_with_vocabulary_list( key='bbb', vocabulary_list=('omar', 'stringer', 'marlo')) - a_embedded, b_embedded = fc_lib._shared_embedding_columns( + a_embedded, b_embedded = fc.shared_embedding_columns( [a, b], dimension=2) data = example_pb2.Example(features=feature_pb2.Features( feature={ @@ -4350,7 +4363,7 @@ class SharedEmbeddingColumnTest(test.TestCase): def test_transform_feature(self): a = fc.categorical_column_with_identity(key='aaa', num_buckets=3) b = fc.categorical_column_with_identity(key='bbb', num_buckets=3) - a_embedded, b_embedded = fc_lib._shared_embedding_columns( + a_embedded, b_embedded = fc.shared_embedding_columns( [a, b], dimension=2) features = { 'aaa': sparse_tensor.SparseTensor( @@ -4420,7 +4433,7 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -4482,7 +4495,7 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -4522,7 +4535,7 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -4628,7 +4641,7 @@ class SharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer, trainable=trainable) diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index ac915157f528e78a960f0a5bf85539955c192eba..d3d8c9c154fbfcc9613acce4e1bdab7df2e7d56d 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -52,6 +52,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.util.tf_export import tf_export def _eager_reshape(tensor, shape, ctx): @@ -59,7 +60,6 @@ def _eager_reshape(tensor, shape, ctx): attr_t = tensor._datatype_enum() # pylint: disable=protected-access attr_tshape, (shape,) = execute.args_to_matching_eager( [shape], ctx, dtypes.int32) - attr_tshape = attr_tshape inputs_flat = [tensor, shape] attrs = ("T", attr_t, "Tshape", attr_tshape) result, = execute.execute( @@ -131,6 +131,7 @@ def convert_to_eager_tensor(value, ctx, dtype=None): return ops.EagerTensor(value, context=handle, device=device, dtype=dtype) +@tf_export("constant") def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """Creates a constant tensor. diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h index afca7277c775062a8efa7052f789f5146636f4b9..c6ab6b106f5ea335424701afebbabba72ece8660 100644 --- a/tensorflow/python/framework/cpp_shape_inference.h +++ b/tensorflow/python/framework/cpp_shape_inference.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ +#ifndef TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ +#define TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ // Must be included first #include "tensorflow/python/lib/core/numpy.h" @@ -51,4 +51,4 @@ std::vector RunCppShapeInference( } // namespace swig } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ +#endif // TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_ diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py index 8f5125dcfef004bcbd5a581c5ff9dea1d85cf57e..ab06a2babf3976347714a98a50f95c07cbb6fdda 100644 --- a/tensorflow/python/framework/device.py +++ b/tensorflow/python/framework/device.py @@ -19,8 +19,10 @@ from __future__ import division from __future__ import print_function import copy +from tensorflow.python.util.tf_export import tf_export +@tf_export("DeviceSpec") class DeviceSpec(object): """Represents a (possibly partial) specification for a TensorFlow device. diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index b0422eb6be091a3fcf4b213f04a2e13a3ae8a963..67ccf990d6a0e59c965ff76c2ba601be2a64060a 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -23,11 +23,13 @@ import numpy as np from tensorflow.core.framework import types_pb2 from tensorflow.python import pywrap_tensorflow +from tensorflow.python.util.tf_export import tf_export _np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() +@tf_export("DType") class DType(object): """Represents the type of the elements in a `Tensor`. @@ -321,32 +323,55 @@ dtype_range = {np.bool_: (False, True), # Define standard wrappers for the types_pb2.DataType enum. resource = DType(types_pb2.DT_RESOURCE) +tf_export("resource").export_constant(__name__, "resource") variant = DType(types_pb2.DT_VARIANT) +tf_export("variant").export_constant(__name__, "variant") float16 = DType(types_pb2.DT_HALF) +tf_export("float16").export_constant(__name__, "float16") half = float16 +tf_export("half").export_constant(__name__, "half") float32 = DType(types_pb2.DT_FLOAT) +tf_export("float32").export_constant(__name__, "float32") float64 = DType(types_pb2.DT_DOUBLE) +tf_export("float64").export_constant(__name__, "float64") double = float64 +tf_export("double").export_constant(__name__, "double") int32 = DType(types_pb2.DT_INT32) +tf_export("int32").export_constant(__name__, "int32") uint8 = DType(types_pb2.DT_UINT8) +tf_export("uint8").export_constant(__name__, "uint8") uint16 = DType(types_pb2.DT_UINT16) +tf_export("uint16").export_constant(__name__, "uint16") uint32 = DType(types_pb2.DT_UINT32) uint64 = DType(types_pb2.DT_UINT64) int16 = DType(types_pb2.DT_INT16) +tf_export("int16").export_constant(__name__, "int16") int8 = DType(types_pb2.DT_INT8) +tf_export("int8").export_constant(__name__, "int8") string = DType(types_pb2.DT_STRING) +tf_export("string").export_constant(__name__, "string") complex64 = DType(types_pb2.DT_COMPLEX64) +tf_export("complex64").export_constant(__name__, "complex64") complex128 = DType(types_pb2.DT_COMPLEX128) +tf_export("complex128").export_constant(__name__, "complex128") int64 = DType(types_pb2.DT_INT64) +tf_export("int64").export_constant(__name__, "int64") bool = DType(types_pb2.DT_BOOL) +tf_export("bool").export_constant(__name__, "bool") qint8 = DType(types_pb2.DT_QINT8) +tf_export("qint8").export_constant(__name__, "qint8") quint8 = DType(types_pb2.DT_QUINT8) +tf_export("quint8").export_constant(__name__, "quint8") qint16 = DType(types_pb2.DT_QINT16) +tf_export("qint16").export_constant(__name__, "qint16") quint16 = DType(types_pb2.DT_QUINT16) +tf_export("quint16").export_constant(__name__, "quint16") qint32 = DType(types_pb2.DT_QINT32) +tf_export("qint32").export_constant(__name__, "qint32") resource_ref = DType(types_pb2.DT_RESOURCE_REF) variant_ref = DType(types_pb2.DT_VARIANT_REF) bfloat16 = DType(types_pb2.DT_BFLOAT16) +tf_export("bfloat16").export_constant(__name__, "bfloat16") float16_ref = DType(types_pb2.DT_HALF_REF) half_ref = float16_ref float32_ref = DType(types_pb2.DT_FLOAT_REF) @@ -578,8 +603,10 @@ _TF_TO_NP = { QUANTIZED_DTYPES = frozenset( [qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref]) +tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES") +@tf_export("as_dtype") def as_dtype(type_value): """Converts the given `type_value` to a `DType`. diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py index c3b2c498c3118087ed57d825ae5c4e66703d8174..2a40316d51c023df9c664d0dd79a0df3b2ac5041 100644 --- a/tensorflow/python/framework/errors_impl.py +++ b/tensorflow/python/framework/errors_impl.py @@ -25,8 +25,10 @@ from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.framework import c_api_util from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export("OpError", "errors.OpError") class OpError(Exception): """A generic error that is raised when TensorFlow execution fails. @@ -133,25 +135,48 @@ class OpError(Exception): OK = error_codes_pb2.OK +tf_export("errors.OK").export_constant(__name__, "OK") CANCELLED = error_codes_pb2.CANCELLED +tf_export("errors.CANCELLED").export_constant(__name__, "CANCELLED") UNKNOWN = error_codes_pb2.UNKNOWN +tf_export("errors.UNKNOWN").export_constant(__name__, "UNKNOWN") INVALID_ARGUMENT = error_codes_pb2.INVALID_ARGUMENT +tf_export("errors.INVALID_ARGUMENT").export_constant(__name__, + "INVALID_ARGUMENT") DEADLINE_EXCEEDED = error_codes_pb2.DEADLINE_EXCEEDED +tf_export("errors.DEADLINE_EXCEEDED").export_constant(__name__, + "DEADLINE_EXCEEDED") NOT_FOUND = error_codes_pb2.NOT_FOUND +tf_export("errors.NOT_FOUND").export_constant(__name__, "NOT_FOUND") ALREADY_EXISTS = error_codes_pb2.ALREADY_EXISTS +tf_export("errors.ALREADY_EXISTS").export_constant(__name__, "ALREADY_EXISTS") PERMISSION_DENIED = error_codes_pb2.PERMISSION_DENIED +tf_export("errors.PERMISSION_DENIED").export_constant(__name__, + "PERMISSION_DENIED") UNAUTHENTICATED = error_codes_pb2.UNAUTHENTICATED +tf_export("errors.UNAUTHENTICATED").export_constant(__name__, "UNAUTHENTICATED") RESOURCE_EXHAUSTED = error_codes_pb2.RESOURCE_EXHAUSTED +tf_export("errors.RESOURCE_EXHAUSTED").export_constant(__name__, + "RESOURCE_EXHAUSTED") FAILED_PRECONDITION = error_codes_pb2.FAILED_PRECONDITION +tf_export("errors.FAILED_PRECONDITION").export_constant(__name__, + "FAILED_PRECONDITION") ABORTED = error_codes_pb2.ABORTED +tf_export("errors.ABORTED").export_constant(__name__, "ABORTED") OUT_OF_RANGE = error_codes_pb2.OUT_OF_RANGE +tf_export("errors.OUT_OF_RANGE").export_constant(__name__, "OUT_OF_RANGE") UNIMPLEMENTED = error_codes_pb2.UNIMPLEMENTED +tf_export("errors.UNIMPLEMENTED").export_constant(__name__, "UNIMPLEMENTED") INTERNAL = error_codes_pb2.INTERNAL +tf_export("errors.INTERNAL").export_constant(__name__, "INTERNAL") UNAVAILABLE = error_codes_pb2.UNAVAILABLE +tf_export("errors.UNAVAILABLE").export_constant(__name__, "UNAVAILABLE") DATA_LOSS = error_codes_pb2.DATA_LOSS +tf_export("errors.DATA_LOSS").export_constant(__name__, "DATA_LOSS") # pylint: disable=line-too-long +@tf_export("errors.CancelledError") class CancelledError(OpError): """Raised when an operation or step is cancelled. @@ -172,6 +197,7 @@ class CancelledError(OpError): # pylint: enable=line-too-long +@tf_export("errors.UnknownError") class UnknownError(OpError): """Unknown error. @@ -189,6 +215,7 @@ class UnknownError(OpError): super(UnknownError, self).__init__(node_def, op, message, error_code) +@tf_export("errors.InvalidArgumentError") class InvalidArgumentError(OpError): """Raised when an operation receives an invalid argument. @@ -209,6 +236,7 @@ class InvalidArgumentError(OpError): INVALID_ARGUMENT) +@tf_export("errors.DeadlineExceededError") class DeadlineExceededError(OpError): """Raised when a deadline expires before an operation could complete. @@ -223,6 +251,7 @@ class DeadlineExceededError(OpError): DEADLINE_EXCEEDED) +@tf_export("errors.NotFoundError") class NotFoundError(OpError): """Raised when a requested entity (e.g., a file or directory) was not found. @@ -239,6 +268,7 @@ class NotFoundError(OpError): super(NotFoundError, self).__init__(node_def, op, message, NOT_FOUND) +@tf_export("errors.AlreadyExistsError") class AlreadyExistsError(OpError): """Raised when an entity that we attempted to create already exists. @@ -256,6 +286,7 @@ class AlreadyExistsError(OpError): ALREADY_EXISTS) +@tf_export("errors.PermissionDeniedError") class PermissionDeniedError(OpError): """Raised when the caller does not have permission to run an operation. @@ -273,6 +304,7 @@ class PermissionDeniedError(OpError): PERMISSION_DENIED) +@tf_export("errors.UnauthenticatedError") class UnauthenticatedError(OpError): """The request does not have valid authentication credentials. @@ -287,6 +319,7 @@ class UnauthenticatedError(OpError): UNAUTHENTICATED) +@tf_export("errors.ResourceExhaustedError") class ResourceExhaustedError(OpError): """Some resource has been exhausted. @@ -302,6 +335,7 @@ class ResourceExhaustedError(OpError): RESOURCE_EXHAUSTED) +@tf_export("errors.FailedPreconditionError") class FailedPreconditionError(OpError): """Operation was rejected because the system is not in a state to execute it. @@ -318,6 +352,7 @@ class FailedPreconditionError(OpError): FAILED_PRECONDITION) +@tf_export("errors.AbortedError") class AbortedError(OpError): """The operation was aborted, typically due to a concurrent action. @@ -335,6 +370,7 @@ class AbortedError(OpError): super(AbortedError, self).__init__(node_def, op, message, ABORTED) +@tf_export("errors.OutOfRangeError") class OutOfRangeError(OpError): """Raised when an operation iterates past the valid input range. @@ -353,6 +389,7 @@ class OutOfRangeError(OpError): OUT_OF_RANGE) +@tf_export("errors.UnimplementedError") class UnimplementedError(OpError): """Raised when an operation has not been implemented. @@ -371,6 +408,7 @@ class UnimplementedError(OpError): UNIMPLEMENTED) +@tf_export("errors.InternalError") class InternalError(OpError): """Raised when the system experiences an internal error. @@ -385,6 +423,7 @@ class InternalError(OpError): super(InternalError, self).__init__(node_def, op, message, INTERNAL) +@tf_export("errors.UnavailableError") class UnavailableError(OpError): """Raised when the runtime is currently unavailable. @@ -399,6 +438,7 @@ class UnavailableError(OpError): UNAVAILABLE) +@tf_export("errors.DataLossError") class DataLossError(OpError): """Raised when unrecoverable data loss or corruption is encountered. @@ -437,10 +477,12 @@ _EXCEPTION_CLASS_TO_CODE = dict(( (class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items())) +@tf_export("errors.exception_type_from_error_code") def exception_type_from_error_code(error_code): return _CODE_TO_EXCEPTION_CLASS[error_code] +@tf_export("errors.error_code_from_exception_type") def error_code_from_exception_type(cls): return _EXCEPTION_CLASS_TO_CODE[cls] @@ -457,7 +499,8 @@ def _make_specific_exception(node_def, op, message, error_code): # Named like a function for backwards compatibility with the # @tf_contextlib.contextmanager version, which was switched to a class to avoid # some object creation overhead. -class raise_exception_on_not_ok_status(object): # pylint: disable=invalid-name +@tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name +class raise_exception_on_not_ok_status(object): """Context manager to check for C API status.""" def __enter__(self): diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 416bbf4f48a6439eded68589ee5687789f42c02b..cba225e749d88a45c43266e45172a7335a8e0b71 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -417,7 +417,7 @@ class _DefinedFunction(object): if self._func_name: assert self._func_name == self._op_def.name else: - self._func_name = self._op_def.name + self._func_name = compat.as_str(self._op_def.name) def _set_c_attrs(self, attrs): """Sets `attrs` as attributes of self._c_func. diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 57e5a724c99bd77df8cd11eff99288fa6647f4ac..a4ca3f9a89bd4cce2240d90895c43dda1acb849b 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.core.framework import function_pb2 from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -451,13 +452,17 @@ class FunctionTest(test.TestCase): lambda y: AssertFail(y), [x]) # pylint: enable=unnecessary-lambda + rewriter_config = rewriter_config_pb2.RewriterConfig( + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) # Enables inlining. - config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( - optimizer_options=config_pb2.OptimizerOptions( - opt_level=config_pb2.OptimizerOptions.L0, - do_common_subexpression_elimination=True, - do_function_inlining=True, - do_constant_folding=True))) + config = config_pb2.ConfigProto( + graph_options=config_pb2.GraphOptions( + optimizer_options=config_pb2.OptimizerOptions( + opt_level=config_pb2.OptimizerOptions.L0, + do_common_subexpression_elimination=True, + do_function_inlining=True, + do_constant_folding=True), + rewrite_options=rewriter_config)) with session.Session(config=config) as sess: # Since the 'False' branch is not taken, the assertion should not fire. diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py index a0ea4ad48eb84b22f42ea840513ebefbf6b4abbe..be30b16f5f0a76469226687fc1a419882b96f133 100644 --- a/tensorflow/python/framework/graph_io.py +++ b/tensorflow/python/framework/graph_io.py @@ -24,8 +24,10 @@ import os.path from google.protobuf import text_format from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io +from tensorflow.python.util.tf_export import tf_export +@tf_export('train.write_graph') def write_graph(graph_or_graph_def, logdir, name, as_text=True): """Writes a graph proto to a file. diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index 6c7b4553881637ce0b2ec63449bde0a397ef2d72..5a543317e665a940841714fd72d834a430f8406a 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export _VARIABLE_OPS = { "Assign", @@ -49,6 +50,7 @@ def _is_variable_op(op): return op in _VARIABLE_OPS +@tf_export("graph_util.must_run_on_cpu") def must_run_on_cpu(node, pin_variables_on_cpu=False): """Returns True if the given node_def must run on CPU, otherwise False. @@ -147,6 +149,7 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name): return nodes_to_keep +@tf_export("graph_util.extract_sub_graph") def extract_sub_graph(graph_def, dest_nodes): """Extract the subgraph that can reach any of the nodes in 'dest_nodes'. @@ -184,6 +187,7 @@ def extract_sub_graph(graph_def, dest_nodes): return out +@tf_export("graph_util.tensor_shape_from_node_def_name") def tensor_shape_from_node_def_name(graph, input_name): """Convenience function to get a shape from a NodeDef's input string.""" # To get a tensor, the name must be in the form :, for example @@ -198,6 +202,7 @@ def tensor_shape_from_node_def_name(graph, input_name): return shape +@tf_export("graph_util.convert_variables_to_constants") def convert_variables_to_constants(sess, input_graph_def, output_node_names, @@ -270,6 +275,7 @@ def convert_variables_to_constants(sess, return output_graph_def +@tf_export("graph_util.remove_training_nodes") def remove_training_nodes(input_graph, protected_nodes=None): """Prunes out nodes that aren't needed for inference. diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index a3dbe43f06eadb311338165bb07c3dccdf0299c3..00fff8d040d6facfc81359061f6cf9a1cf6d3d3c 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated_args +from tensorflow.python.util.tf_export import tf_export # TODO(josh11b): SWIG the code from node_def_util instead of duplicating @@ -369,6 +370,7 @@ def _GatherReturnElements(requested_return_elements, graph, results): return combined_return_elements +@tf_export('import_graph_def') @deprecated_args(None, 'Please file an issue at ' 'https://github.com/tensorflow/tensorflow/issues if you depend' ' on this feature.', diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index 909e6d4c7be76743211d4c9045706fce62d4910e..c997ead829855f33efdb3efe947c3f59b5dbe76c 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -28,8 +28,10 @@ from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.python import pywrap_tensorflow as py_tf from tensorflow.python.framework import errors_impl from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export('load_op_library') def load_op_library(library_filename): """Loads a TensorFlow plugin, containing custom ops and kernels. @@ -79,6 +81,7 @@ def load_op_library(library_filename): return module +@tf_export('load_file_system_library') def load_file_system_library(library_filename): """Loads a TensorFlow plugin, containing file system implementation. diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e7f08a64a622c7a8332aa095ad6de86015d18a2e..b107670275c87e2ee711c1a10fbe6bacc334ad5f 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.tf_export import tf_export # Temporary global switch determining if we should enable the work-in-progress @@ -191,6 +192,7 @@ class _TensorLike(object): pass +@tf_export("Tensor") class Tensor(_TensorLike): """Represents one of the outputs of an `Operation`. @@ -285,7 +287,7 @@ class Tensor(_TensorLike): self._op = op self._value_index = value_index self._dtype = dtypes.as_dtype(dtype) - self._shape = tensor_shape.unknown_shape() + self._shape_val = tensor_shape.unknown_shape() # List of operations that use this Tensor as input. We maintain this list # to easily navigate a computation graph. self._consumers = [] @@ -379,7 +381,18 @@ class Tensor(_TensorLike): graph, self._as_tf_output(), num_dims, status) dim_list = [None if i == -1 else i for i in dim_list] return tensor_shape.TensorShape(dim_list) - return self._shape + return self._shape_val + + @property + def _shape(self): + logging.warning("Tensor._shape is private, use Tensor.shape " + "instead. Tensor._shape will eventually be removed.") + return self.shape + + @_shape.setter + def _shape(self, value): + raise ValueError( + "Tensor._shape cannot be assigned, use Tensor.set_shape instead.") def __iter__(self): if context.in_graph_mode(): @@ -454,7 +467,7 @@ class Tensor(_TensorLike): this tensor. """ if not _USE_C_API: - self._shape = self._shape.merge_with(shape) # pylint: disable=protected-access + self._shape_val = self._shape_val.merge_with(shape) return if not isinstance(shape, tensor_shape.TensorShape): shape = tensor_shape.TensorShape(shape) @@ -468,13 +481,17 @@ class Tensor(_TensorLike): dim_list.append(-1) else: dim_list.append(dim.value) - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_GraphSetTensorShape_wrapper( - self._op._graph._c_graph, # pylint: disable=protected-access - self._as_tf_output(), - dim_list, - unknown_shape, - status) + try: + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_GraphSetTensorShape_wrapper( + self._op._graph._c_graph, # pylint: disable=protected-access + self._as_tf_output(), + dim_list, + unknown_shape, + status) + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) @property def value_index(self): @@ -775,6 +792,11 @@ class _EagerTensorBase(Tensor): """The shape of the tensor as a list.""" return list(self._shape_tuple()) + @property + def ndim(self): + """Returns the number of Tensor dimensions.""" + return self.shape.ndims + def cpu(self): """A copy of this Tensor with contents backed by host memory.""" return self._copy(context.context(), "CPU:0") @@ -866,6 +888,7 @@ _tensor_conversion_func_lock = threading.Lock() register_dense_tensor_like_type(Tensor) +@tf_export("convert_to_tensor") def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None): """Converts the given `value` to a `Tensor`. @@ -1111,6 +1134,7 @@ def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): as_ref=False) +@tf_export("convert_to_tensor_or_indexed_slices") def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): """Converts the given object to a `Tensor` or an `IndexedSlices`. @@ -1241,6 +1265,7 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): # TODO(josh11b): Add ctx argument to conversion_func() signature. +@tf_export("register_tensor_conversion_function") def register_tensor_conversion_function(base_type, conversion_func, priority=100): @@ -1301,6 +1326,7 @@ def register_tensor_conversion_function(base_type, _tensor_conversion_func_cache = {} +@tf_export("IndexedSlices") class IndexedSlices(_TensorLike): """A sparse representation of a set of tensor slices at given indices. @@ -1481,6 +1507,7 @@ def _create_c_op(graph, node_def, inputs, control_inputs): return c_op +@tf_export("Operation") class Operation(object): """Represents a graph node that performs computation on tensors. @@ -1560,7 +1587,6 @@ class Operation(object): "Cannot create a tensor proto whose content is larger than 2GB.") if not _VALID_OP_NAME_REGEX.match(node_def.name): raise ValueError("'%s' is not a valid node name" % node_def.name) - self._node_def = copy.deepcopy(node_def) c_op = None elif type(node_def).__name__ == "SwigPyObject": assert inputs is None @@ -1569,7 +1595,6 @@ class Operation(object): assert input_types is None assert original_op is None assert op_def is None - self._node_def = None c_op = node_def else: raise TypeError("node_def needs to be a NodeDef: %s" % node_def) @@ -1577,28 +1602,27 @@ class Operation(object): if not isinstance(g, Graph): raise TypeError("g needs to be a Graph: %s" % g) self._graph = g + if inputs is None: inputs = [] elif not isinstance(inputs, list): raise TypeError("inputs needs to be a list of Tensors: %s" % inputs) - self._inputs = list(inputs) # Defensive copy. - for a in self._inputs: + for a in inputs: if not isinstance(a, Tensor): raise TypeError("input needs to be a Tensor: %s" % a) if input_types is None: - input_types = [i.dtype.base_dtype for i in self._inputs] + input_types = [i.dtype.base_dtype for i in inputs] else: if not all( x.is_compatible_with(i.dtype) - for i, x in zip(self._inputs, input_types)): + for i, x in zip(inputs, input_types)): raise TypeError("In op '%s', input types (%s) are not compatible " "with expected types (%s)" % - (self.node_def.name, [i.dtype for i in self._inputs], + (self.node_def.name, [i.dtype for i in inputs], input_types)) - self._input_types_val = input_types # Build the list of control inputs. - self._control_inputs = [] + control_input_ops = [] if control_inputs: for c in control_inputs: control_op = None @@ -1609,11 +1633,20 @@ class Operation(object): else: raise TypeError("Control input must be an Operation, " "a Tensor, or IndexedSlices: %s" % c) - self._control_inputs.append(control_op) + control_input_ops.append(control_op) + + # Don't set private fields with C API enabled to catch users who need to + # switch to public API. + # TODO(skyewm): delete these fields once we remove _USE_C_API + if not self._graph._c_graph: + self._inputs_val = list(inputs) # Defensive copy. + self._input_types_val = input_types + self._control_inputs_val = control_input_ops + self._node_def_val = copy.deepcopy(node_def) + self._op_def_val = op_def self._id_value = self._graph._next_id() # pylint: disable=protected-access self._original_op = original_op - self._op_def = op_def self._traceback = self._graph._extract_stack() # pylint: disable=protected-access self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access @@ -1629,15 +1662,15 @@ class Operation(object): # Refactor so we don't have to do this here. grouped_inputs = self._reconstruct_sequence_inputs( op_def, inputs, node_def.attr) - self._c_op = _create_c_op(self._graph, self._node_def, grouped_inputs, - self._control_inputs) + self._c_op = _create_c_op(self._graph, node_def, grouped_inputs, + control_input_ops) else: self._c_op = None # Mark that we consume the inputs. This is unnecessary and unsupported with # the C API enabled, since the C API tracks the tensor consumers instead. if not self._c_op: - for input_tensor in self._inputs: + for input_tensor in self._inputs_val: input_tensor._add_consumer(self) # pylint: disable=protected-access # Initialize self._outputs. @@ -1752,7 +1785,7 @@ class Operation(object): if self._c_op: return c_api.TF_OperationName(self._c_op) else: - return self._node_def.name + return self._node_def_val.name @property def _id(self): @@ -1771,7 +1804,7 @@ class Operation(object): if self._c_op: return c_api.TF_OperationDevice(self._c_op) else: - return self._node_def.device + return self._node_def_val.device @property def _output_types(self): @@ -1831,7 +1864,7 @@ class Operation(object): self._c_op, # pylint: disable=protected-access compat.as_str(_device_string(device))) else: - self._node_def.device = _device_string(device) + self._node_def_val.device = _device_string(device) def _add_input(self, tensor, dtype=None): """Add a new input to this operation. @@ -1859,7 +1892,7 @@ class Operation(object): raise TypeError( "Cannot convert a tensor of type %s to an input of type %s" % (tensor.dtype.name, dtype.name)) - self._inputs.append(tensor) + self._inputs_val.append(tensor) self._input_types_val.append(dtype) tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() @@ -1889,8 +1922,8 @@ class Operation(object): self._tf_input(index), status) else: - self._inputs[index].consumers().remove(self) - self._inputs[index] = tensor + self._inputs_val[index].consumers().remove(self) + self._inputs_val[index] = tensor self._input_types_val[index] = tensor.dtype tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() @@ -1916,7 +1949,7 @@ class Operation(object): if not isinstance(op, Operation): raise TypeError("op must be an Operation: %s" % op) _assert_same_graph(self, op) - self._control_inputs.append(op) + self._control_inputs_val.append(op) self._recompute_node_def() def _add_control_input(self, op): @@ -1948,13 +1981,14 @@ class Operation(object): # TODO(skyewm): remove this function when we switch to C API if self._c_op: return - del self._node_def.input[:] + del self._node_def_val.input[:] # pylint: disable=protected-access - self._node_def.input.extend([t._as_node_def_input() for t in self._inputs]) + self._node_def_val.input.extend( + [t._as_node_def_input() for t in self._inputs_val]) # pylint: enable=protected-access - if self._control_inputs: - self._node_def.input.extend( - ["^%s" % op.name for op in self._control_inputs]) + if self._control_inputs_val: + self._node_def_val.input.extend( + ["^%s" % op.name for op in self._control_inputs_val]) def __str__(self): return str(self.node_def) @@ -2004,7 +2038,17 @@ class Operation(object): ] # pylint: enable=protected-access return Operation._InputList(retval) - return Operation._InputList(self._inputs) + return Operation._InputList(self._inputs_val) + + @property + def _inputs(self): + logging.warning("Operation._inputs is private, use Operation.inputs " + "instead. Operation._inputs will eventually be removed.") + return self.inputs + + @_inputs.setter + def _inputs(self, value): + raise ValueError("Cannot assign _inputs") @property def _input_types(self): @@ -2018,6 +2062,10 @@ class Operation(object): else: return self._input_types_val + @_input_types.setter + def _input_types(self, value): + raise ValueError("Cannot assign _input_types") + @property def control_inputs(self): """The `Operation` objects on which this op has a control dependency. @@ -2041,7 +2089,22 @@ class Operation(object): ] # pylint: enable=protected-access else: - return self._control_inputs + return self._control_inputs_val + + @property + def _control_inputs(self): + logging.warning("Operation._control_inputs is private, use " + "Operation.control_inputs instead. " + "Operation._control_inputs will eventually be removed.") + return self.control_inputs + + @_control_inputs.setter + def _control_inputs(self, value): + logging.warning("Operation._control_inputs is private, use " + "Operation.control_inputs instead. " + "Operation._control_inputs will eventually be removed.") + self._remove_all_control_inputs() + self._add_control_inputs(value) @property def type(self): @@ -2050,7 +2113,7 @@ class Operation(object): op_type = c_api.TF_OperationOpType(self._c_op) return op_type else: - return self._node_def.op + return self._node_def_val.op @property def graph(self): @@ -2077,7 +2140,13 @@ class Operation(object): node_def.ParseFromString(compat.as_bytes(data)) return node_def else: - return self._node_def + return self._node_def_val + + @property + def _node_def(self): + logging.warning("Operation._node_def is private, use Operation.node_def " + "instead. Operation._node_def will eventually be removed.") + return self.node_def @property def op_def(self): @@ -2102,7 +2171,13 @@ class Operation(object): op_def.ParseFromString(compat.as_bytes(data)) return op_def else: - return self._op_def + return self._op_def_val + + @property + def _op_def(self): + logging.warning("Operation._op_def is private, use Operation.op_def " + "instead. Operation._op_def will eventually be removed.") + return self.op_def @property def traceback(self): @@ -2134,7 +2209,7 @@ class Operation(object): finally: c_api.TF_DeleteBuffer(buf) else: - self._node_def.attr[attr_name].CopyFrom(attr_value) + self._node_def_val.attr[attr_name].CopyFrom(attr_value) def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. @@ -2161,10 +2236,10 @@ class Operation(object): x = attr_value_pb2.AttrValue() x.ParseFromString(data) else: - if name not in self._node_def.attr: + if name not in self._node_def_val.attr: raise ValueError( - "No attr named '" + name + "' in " + str(self._node_def)) - x = self._node_def.attr[name] + "No attr named '" + name + "' in " + str(self._node_def_val)) + x = self._node_def_val.attr[name] # Treat an empty oneof value as an empty list. if not x.WhichOneof("value"): @@ -2208,6 +2283,7 @@ class Operation(object): _gradient_registry = registry.Registry("gradient") +@tf_export("RegisterGradient") class RegisterGradient(object): """A decorator for registering the gradient function for an op type. @@ -2250,6 +2326,7 @@ class RegisterGradient(object): return f +@tf_export("NoGradient", "NotDifferentiable") def NotDifferentiable(op_type): """Specifies that ops of type `op_type` is not differentiable. @@ -2569,6 +2646,7 @@ def _name_from_scope_name(name): return name[:-1] if (name and name[-1] == "/") else name +@tf_export("Graph") class Graph(object): """A TensorFlow computation, represented as a dataflow graph. @@ -2695,6 +2773,7 @@ class Graph(object): self._scoped_c_graph = c_api_util.ScopedTFGraph() else: self._scoped_c_graph = None + self._variable_creator_stack = [] # TODO(apassos) remove once the C API is used by default. def _use_c_api_hack(self): @@ -2731,6 +2810,22 @@ class Graph(object): ret.append((filename, lineno, name, line)) return ret + # Note: this method is private because the API of tf.Graph() is public and + # frozen, and this functionality is still not ready for public visibility. + @tf_contextlib.contextmanager + def _variable_creator_scope(self, creator): + old = list(self._variable_creator_stack) + self._variable_creator_stack.append(creator) + try: + yield + finally: + self._variable_creator_stack = old + + # Note: this method is private because the API of tf.Graph() is public and + # frozen, and this functionality is still not ready for public visibility. + def _get_variable_creator_stack(self): + return list(self._variable_creator_stack) + def _extract_stack(self): """A lightweight, extensible re-implementation of traceback.extract_stack. @@ -4164,10 +4259,10 @@ class Graph(object): """ self._graph = graph if control_inputs is None: - self._control_inputs = [] + self._control_inputs_val = [] self._new_stack = True else: - self._control_inputs = control_inputs + self._control_inputs_val = control_inputs self._new_stack = False self._seen_nodes = set() self._old_stack = None @@ -4195,7 +4290,7 @@ class Graph(object): @property def control_inputs(self): - return self._control_inputs + return self._control_inputs_val def add_op(self, op): self._seen_nodes.add(op) @@ -4569,6 +4664,9 @@ class Graph(object): # TODO(agarwal): currently device directives in an outer eager scope will not # apply to inner graph mode code. Fix that. + + +@tf_export("device") def device(device_name_or_function): """Wrapper for `Graph.device()` using the default graph. @@ -4598,6 +4696,7 @@ def device(device_name_or_function): return context.device(device_name_or_function) +@tf_export("container") def container(container_name): """Wrapper for `Graph.container()` using the default graph. @@ -4611,6 +4710,7 @@ def container(container_name): return get_default_graph().container(container_name) +@tf_export("colocate_with") def colocate_with(op, ignore_existing=False): if context.in_graph_mode(): return get_default_graph().colocate_with(op, ignore_existing) @@ -4621,6 +4721,7 @@ def colocate_with(op, ignore_existing=False): return _NullContextmanager() +@tf_export("control_dependencies") def control_dependencies(control_inputs): """Wrapper for `Graph.control_dependencies()` using the default graph. @@ -4738,6 +4839,7 @@ def default_session(session): return _default_session_stack.get_controller(session) +@tf_export("get_default_session") def get_default_session(): """Returns the default session for the current thread. @@ -4950,6 +5052,8 @@ def enable_eager_execution(config=None, device_policy=None): right device but raises a warning. tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide performance problems. + tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, + raising errors on the other ones. Raises: ValueError: If trying to create a context after using graph operations @@ -4961,10 +5065,10 @@ def enable_eager_execution(config=None, device_policy=None): "config must be a tf.ConfigProto, but got %s" % type(config)) if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT, context.DEVICE_PLACEMENT_WARN, - context.DEVICE_PLACEMENT_SILENT): + context.DEVICE_PLACEMENT_SILENT, + context.DEVICE_PLACEMENT_SILENT_FOR_INT32): raise ValueError( - "device_policy must be one of None, tfe.DEVICE_PLACEMENT_EXPLICIT, " - "tfe.DEVICE_PLACEMENT_WARN, tfe.DEVICE_PLACEMENT_SILENT" + "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*" ) # pylint: disable=protected-access if context._default_mode == context.GRAPH_MODE: @@ -5027,6 +5131,7 @@ def eager_run(main=None, argv=None): app.run(main, argv) +@tf_export("reset_default_graph") def reset_default_graph(): """Clears the default graph stack and resets the global default graph. @@ -5045,6 +5150,7 @@ def reset_default_graph(): _default_graph_stack.reset() +@tf_export("get_default_graph") def get_default_graph(): """Returns the default graph for the current thread. @@ -5165,6 +5271,7 @@ def _get_graph_from_inputs(op_input_list, graph=None): return graph or get_default_graph() +@tf_export("GraphKeys") class GraphKeys(object): """Standard names to use for graph collections. @@ -5313,6 +5420,7 @@ class GraphKeys(object): return cls.GLOBAL_VARIABLES +@tf_export("add_to_collection") def add_to_collection(name, value): """Wrapper for `Graph.add_to_collection()` using the default graph. @@ -5349,6 +5457,7 @@ def add_to_collections(names, value): get_default_graph().add_to_collections(names, value) +@tf_export("get_collection_ref") def get_collection_ref(key): """Wrapper for `Graph.get_collection_ref()` using the default graph. @@ -5372,6 +5481,7 @@ def get_collection_ref(key): return get_default_graph().get_collection_ref(key) +@tf_export("get_collection") def get_collection(key, scope=None): """Wrapper for `Graph.get_collection()` using the default graph. @@ -5408,6 +5518,7 @@ def get_all_collection_keys(): # Named like a function for backwards compatibility with the # @tf_contextlib.contextmanager version, which was switched to a class to avoid # some object creation overhead. +@tf_export("name_scope", "keras.backend.name_scope") class name_scope(object): # pylint: disable=invalid-name """A context manager for use when defining a Python op. @@ -5554,6 +5665,7 @@ def prepend_name_scope(name, import_scope): # pylint: disable=g-doc-return-or-yield # pylint: disable=not-context-manager +@tf_export("op_scope") @tf_contextlib.contextmanager def op_scope(values, name, default_name=None): """DEPRECATED. Same as name_scope above, just different argument order.""" diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index 6b53825a6d325c00eaf9f60fbcd9d4e0f9c9183c..d09b36a3e8247241420649c6a4a950be6edc3c00 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ +#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ +#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ #include @@ -112,4 +112,4 @@ class GenPythonOp { } // namespace python_op_gen_internal } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ +#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py index 5f1130570d2ec9bd964abeb7526ab03f14e067a3..1e74a790a3fb0c72b7c0fb1127ffac95f386d85e 100644 --- a/tensorflow/python/framework/random_seed.py +++ b/tensorflow/python/framework/random_seed.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.util.tf_export import tf_export DEFAULT_GRAPH_SEED = 87654321 @@ -32,6 +33,7 @@ def _truncate_seed(seed): return seed % _MAXINT32 # Truncate to fit into 32-bit integer +@tf_export('get_seed') def get_seed(op_seed): """Returns the local seeds an operation should use given an op-specific seed. @@ -78,6 +80,7 @@ def get_seed(op_seed): return seeds +@tf_export('set_random_seed') def set_random_seed(seed): """Sets the graph-level random seed. diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 6218cc34cad50aa6e291dcffcf352c717e0d85f0..1fe81e5f17a7de0a113596d920d63e5d9474c7c1 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -23,6 +23,7 @@ import collections from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access _TensorLike = ops._TensorLike @@ -31,6 +32,7 @@ _override_helper = ops._override_helper # pylint: enable=protected-access +@tf_export("SparseTensor") class SparseTensor(_TensorLike): """Represents a sparse tensor. @@ -222,8 +224,10 @@ class SparseTensor(_TensorLike): SparseTensorValue = collections.namedtuple( "SparseTensorValue", ["indices", "values", "dense_shape"]) +tf_export("SparseTensorValue")(SparseTensorValue) +@tf_export("convert_to_tensor_or_sparse_tensor") def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None): """Converts value to a `SparseTensor` or `Tensor`. diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 54ec15ea66d637b3ef00c38d089e8cbd1c75444c..222071cb9e87aa0fdd9788d1c72df4c66ea61547 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -19,8 +19,10 @@ from __future__ import print_function from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export("Dimension") class Dimension(object): """Represents the value of one dimension in a TensorShape.""" @@ -397,6 +399,7 @@ def as_dimension(value): return Dimension(value) +@tf_export("TensorShape") class TensorShape(object): """Represents the shape of a `Tensor`. diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 1b90c7ad4d68287bfa5c1c74c82d2936a20e4a80..d2b8e80305724fd12341bc089d8e0a63c40b6688 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -38,6 +38,7 @@ except ImportError: from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util.tf_export import tf_export # pylint: enable=g-import-not-at-top @@ -328,6 +329,7 @@ def _AssertCompatible(values, dtype): (dtype.name, repr(mismatch), type(mismatch).__name__)) +@tf_export("make_tensor_proto") def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): """Create a TensorProto. @@ -515,6 +517,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): return tensor_proto +@tf_export("make_ndarray") def MakeNdarray(tensor): """Create a numpy ndarray from a tensor. diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 729c93987017a166cb346004c9880e764eeec4ef..0133318456219b35be11bc5ef128406292bc2feb 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -65,8 +65,10 @@ from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util.protobuf import compare +from tensorflow.python.util.tf_export import tf_export +@tf_export("test.gpu_device_name") def gpu_device_name(): """Returns the name of a GPU device if available or the empty string.""" for x in device_lib.list_local_devices(): @@ -101,6 +103,7 @@ def assert_ops_in_graph(expected_ops, graph): return actual_ops +@tf_export("test.assert_equal_graph_def") def assert_equal_graph_def(actual, expected, checkpoint_v2=False): """Asserts that two `GraphDef`s are (mostly) the same. @@ -630,6 +633,7 @@ def run_in_graph_and_eager_modes( return decorator +@tf_export("test.is_gpu_available") def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): """Returns whether TensorFlow can access a GPU. @@ -678,6 +682,7 @@ def device(use_gpu): yield +@tf_export("test.TestCase") class TensorFlowTestCase(googletest.TestCase): """Base class for tests that need to test TensorFlow. """ @@ -1125,43 +1130,90 @@ class TensorFlowTestCase(googletest.TestCase): print("not close dif = ", np.abs(x - y)) print("not close tol = ", atol + rtol * np.abs(y)) print("dtype = %s, shape = %s" % (a.dtype, a.shape)) - np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=msg) + # TODO(xpan): There seems to be a bug: + # tensorflow/compiler/tests:binary_ops_test pass with float32 + # nan even though the equal_nan is False by default internally. + np.testing.assert_allclose( + a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True) - def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): - """Asserts that two numpy arrays, or dicts of same, have near values. - - This does not support nested dicts. `a` and `b` can be namedtuples too, - which are converted to dicts. - - Args: - a: The expected numpy ndarray (or anything can be converted to one), or - dict of same. Must be a dict iff `b` is a dict. - b: The actual numpy ndarray (or anything can be converted to one), or - dict of same. Must be a dict iff `a` is a dict. - rtol: relative tolerance. - atol: absolute tolerance. + def _assertAllCloseRecursive(self, a, b, rtol=1e-6, atol=1e-6, path=None): + path = path or [] + path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "") - Raises: - ValueError: if only one of `a` and `b` is a dict. - """ # Check if a and/or b are namedtuples. if hasattr(a, "_asdict"): a = a._asdict() if hasattr(b, "_asdict"): b = b._asdict() - is_a_dict = isinstance(a, dict) - if is_a_dict != isinstance(b, dict): - raise ValueError("Can't compare dict to non-dict, %s vs %s." % (a, b)) - if is_a_dict: + a_is_dict = isinstance(a, dict) + if a_is_dict != isinstance(b, dict): + raise ValueError("Can't compare dict to non-dict, a%s vs b%s." % + (path_str, path_str)) + if a_is_dict: self.assertItemsEqual( - a.keys(), b.keys(), - msg="mismatched keys, expected %s, got %s" % (a.keys(), b.keys())) + a.keys(), + b.keys(), + msg="mismatched keys: a%s has keys %s, but b%s has keys %s" % + (path_str, a.keys(), path_str, b.keys())) for k in a: + path.append(k) + self._assertAllCloseRecursive( + a[k], b[k], rtol=rtol, atol=atol, path=path) + del path[-1] + elif isinstance(a, (list, tuple)): + # Try to directly compare a, b as ndarrays; if not work, then traverse + # through the sequence, which is more expensive. + try: + a_as_ndarray = np.array(a) + b_as_ndarray = np.array(b) self._assertArrayLikeAllClose( - a[k], b[k], rtol=rtol, atol=atol, - msg="%s: expected %s, got %s." % (k, a, b)) + a_as_ndarray, + b_as_ndarray, + rtol=rtol, + atol=atol, + msg="Mismatched value: a%s is different from b%s." % (path_str, + path_str)) + except (ValueError, TypeError) as e: + if len(a) != len(b): + raise ValueError( + "Mismatched length: a%s has %d items, but b%s has %d items" % + (path_str, len(a), path_str, len(b))) + for idx, (a_ele, b_ele) in enumerate(zip(a, b)): + path.append(str(idx)) + self._assertAllCloseRecursive( + a_ele, b_ele, rtol=rtol, atol=atol, path=path) + del path[-1] + # a and b are ndarray like objects else: - self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol) + self._assertArrayLikeAllClose( + a, + b, + rtol=rtol, + atol=atol, + msg="Mismatched value: a%s is different from b%s." % (path_str, + path_str)) + + def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): + """Asserts that two structures of numpy arrays, have near values. + + `a` and `b` can be arbitrarily nested structures. A layer of a nested + structure can be a `dict`, `namedtuple`, `tuple` or `list`. + + Args: + a: The expected numpy `ndarray`, or anything that can be converted into a + numpy `ndarray`, or any arbitrarily nested of structure of these. + b: The actual numpy `ndarray`, or anything that can be converted into a + numpy `ndarray`, or any arbitrarily nested of structure of these. + rtol: relative tolerance. + atol: absolute tolerance. + + Raises: + ValueError: if only one of `a[p]` and `b[p]` is a dict or + `a[p]` and `b[p]` have different length, where `[p]` denotes a path + to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and + `[p] = [1]['d']`, then `a[p] = (6, 7)`. + """ + self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol) def assertAllCloseAccordingToType(self, a, @@ -1326,6 +1378,7 @@ class TensorFlowTestCase(googletest.TestCase): # pylint: enable=invalid-name +@tf_export("test.create_local_cluster") def create_local_cluster(num_workers, num_ps, protocol="grpc", worker_config=None, ps_config=None): """Create and start local servers and return the associated `Server` objects. diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 6ddb3533e5bfcaca1dfea95ee35c078427e7529b..3594d125bf616917727bea4958eaabf159d0aee0 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import copy import random import threading @@ -252,12 +253,30 @@ class TestUtilTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"): self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy}) - def testAllCloseNestedDicts(self): - a = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}} - with self.assertRaisesRegexp( - TypeError, - r"inputs could not be safely coerced to any supported types"): - self.assertAllClose(a, a) + def testAllCloseListOfNamedtuples(self): + my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"]) + l1 = [ + my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])), + my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]])) + ] + l2 = [ + ([[2.3, 2.5]], [[0.97, 0.96]]), + ([[3.3, 3.5]], [[0.98, 0.99]]), + ] + self.assertAllClose(l1, l2) + + def testAllCloseNestedStructure(self): + a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])} + self.assertAllClose(a, a) + + b = copy.deepcopy(a) + self.assertAllClose(a, b) + + # Test mismatched values + b["y"][1][0]["nested"]["n"] = 4.2 + with self.assertRaisesRegexp(AssertionError, + r"\[y\]\[1\]\[0\]\[nested\]\[n\]"): + self.assertAllClose(a, b) def testArrayNear(self): a = [1, 2] @@ -282,6 +301,9 @@ class TestUtilTest(test_util.TensorFlowTestCase): control_flow_ops.Assert(x, y).run() def testAssertAllCloseAccordingToType(self): + # test plain int + self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8) + # test float64 self.assertAllCloseAccordingToType( np.asarray([1e-8], dtype=np.float64), diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py index f03b81eb28a7073873579390eae133d3c930c5a0..bdcbc15af63c57d712abfac97537f86b3bbe1737 100644 --- a/tensorflow/python/framework/versions.py +++ b/tensorflow/python/framework/versions.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow +from tensorflow.python.util.tf_export import tf_export __version__ = pywrap_tensorflow.__version__ __git_version__ = pywrap_tensorflow.__git_version__ @@ -28,16 +29,24 @@ __cxx11_abi_flag__ = pywrap_tensorflow.__cxx11_abi_flag__ __monolithic_build__ = pywrap_tensorflow.__monolithic_build__ VERSION = __version__ +tf_export("VERSION").export_constant(__name__, "VERSION") GIT_VERSION = __git_version__ +tf_export("GIT_VERSION").export_constant(__name__, "GIT_VERSION") COMPILER_VERSION = __compiler_version__ +tf_export("COMPILER_VERSION").export_constant(__name__, "COMPILER_VERSION") CXX11_ABI_FLAG = __cxx11_abi_flag__ MONOLITHIC_BUILD = __monolithic_build__ GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION +tf_export("GRAPH_DEF_VERSION").export_constant(__name__, "GRAPH_DEF_VERSION") GRAPH_DEF_VERSION_MIN_CONSUMER = ( pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_CONSUMER) +tf_export("GRAPH_DEF_VERSION_MIN_CONSUMER").export_constant( + __name__, "GRAPH_DEF_VERSION_MIN_CONSUMER") GRAPH_DEF_VERSION_MIN_PRODUCER = ( pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_PRODUCER) +tf_export("GRAPH_DEF_VERSION_MIN_PRODUCER").export_constant( + __name__, "GRAPH_DEF_VERSION_MIN_PRODUCER") __all__ = [ "__version__", diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 25c5ef6b68452c0b8f8dc67a15187db1df5e3934..578f86ca5a0c1f2446dbf26ce412e34f3bdbd23a 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -376,7 +376,7 @@ class LayoutOptimizerTest(test.TestCase): self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('Pad-0-0', nodes) - self.assertIn('Pad-PaddingsConst-LayoutOptimizer', nodes) + self.assertIn('Pad-1-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testReduceSum(self): @@ -587,7 +587,7 @@ class LayoutOptimizerTest(test.TestCase): self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('concat-0-0', nodes) - self.assertIn('concat-Const_2-LayoutOptimizer', nodes) + self.assertIn('concat-2-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testFill(self): @@ -698,7 +698,7 @@ class LayoutOptimizerTest(test.TestCase): self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('ReverseV2-0-0', nodes) - self.assertIn('ReverseV2-DimsConst-LayoutOptimizer', nodes) + self.assertIn('ReverseV2-1-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testReverseWithNonConstDims(self): @@ -867,7 +867,7 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('MaxPoolV2-0-0', nodes) self._assert_vec_nhwc_to_nchw('MaxPoolV2-2', nodes) - self.assertIn('MaxPoolV2-Const_2-LayoutOptimizer', nodes) + self.assertIn('MaxPoolV2-1-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testMaxPoolGradV2(self): @@ -904,7 +904,7 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('MaxPoolGradV2-0-0', nodes) self._assert_vec_nhwc_to_nchw('MaxPoolGradV2-4', nodes) - self.assertIn('MaxPoolGradV2-Const_2-LayoutOptimizer', nodes) + self.assertIn('MaxPoolGradV2-3-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testSliceWithNonConstAxis(self): @@ -977,16 +977,17 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('StridedSlice-0-0', nodes) self._assert_vec_nhwc_to_nchw('StridedSlice-2', nodes) - self.assertIn('StridedSlice-StridedSlice/begin-LayoutOptimizer', nodes) - self.assertIn('StridedSlice-StridedSlice/strides-LayoutOptimizer', nodes) + self.assertIn('StridedSlice-1-LayoutOptimizer', nodes) + self.assertIn('StridedSlice-3-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) - def testStridedSliceWithMask(self): + def testStridedSliceWithMask1011(self): if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(0) x = random_ops.truncated_normal([1, 784], seed=0) conv = _two_layer_model(x) - # This will generate a StridedSlice op with begin mask and end mask. + # This will generate a StridedSlice op with begin mask and + # end mask 11(1011). s = conv[:, :, 1:-1, :] output = array_ops.identity(s) @@ -1010,11 +1011,44 @@ class LayoutOptimizerTest(test.TestCase): self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('strided_slice-0-0', nodes) - self.assertIn('strided_slice-strided_slice/stack-LayoutOptimizer', nodes) - self.assertIn('strided_slice-strided_slice/stack_1-LayoutOptimizer', - nodes) - self.assertIn('strided_slice-strided_slice/stack_2-LayoutOptimizer', - nodes) + self.assertIn('strided_slice-1-LayoutOptimizer', nodes) + self.assertIn('strided_slice-2-LayoutOptimizer', nodes) + self.assertIn('strided_slice-3-LayoutOptimizer', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testStridedSliceWithMask0111(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + # This will generate a StridedSlice op with begin mask and + # end mask 7(0111). + s = conv[:, :, :, 1:-1] + output = array_ops.identity(s) + + with session.Session() as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self._assert_trans_nchw_to_nhwc('strided_slice-0-0', nodes) + self.assertIn('strided_slice-1-LayoutOptimizer', nodes) + self.assertIn('strided_slice-2-LayoutOptimizer', nodes) + self.assertIn('strided_slice-3-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testStridedSliceGradWithNonConstAxis(self): @@ -1055,10 +1089,8 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('StridedSliceGrad-0-0', nodes) self._assert_vec_nhwc_to_nchw('StridedSliceGrad-2', nodes) - self.assertIn('StridedSlice-StridedSliceGrad/begin-LayoutOptimizer', - nodes) - self.assertIn('StridedSlice-StridedSliceGrad/strides-LayoutOptimizer', - nodes) + self.assertIn('StridedSlice-1-LayoutOptimizer', nodes) + self.assertIn('StridedSlice-2-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testShapeN(self): diff --git a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py index 5d5d2c4f75003847306aad88a7a1f4804ee48707..0570e9bc0c7344641edf44cd5ef03a4f09005061 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py @@ -21,8 +21,10 @@ from __future__ import print_function import numpy as np from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.boston_housing.load_data') def load_data(path='boston_housing.npz', seed=113, test_split=0.2): """Loads the Boston Housing dataset. diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py index 7905da66c1e619153c75d7e05cad748710d63849..1971f434b9af820af287a3848ef538f5163a2a9a 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py @@ -25,8 +25,10 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.cifar10.load_data') def load_data(): """Loads CIFAR10 dataset. diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py index b69c0724c58d6d60a291c69db3de926605d90954..f4039e935076a55baaf471ad544986082a4e4ad8 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py @@ -25,8 +25,10 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.cifar100.load_data') def load_data(label_mode='fine'): """Loads CIFAR100 dataset. diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py index 7d55ebc8e47c86d2b0e24ea3802012b6e9d1d3a9..7946c46960ef15fdcaff6b5ad9f0bc2623a84b17 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py +++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py @@ -24,8 +24,10 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.imdb.load_data') def load_data(path='imdb.npz', num_words=None, skip_top=0, @@ -133,6 +135,7 @@ def load_data(path='imdb.npz', return (x_train, y_train), (x_test, y_test) +@tf_export('keras.datasets.imdb.get_word_index') def get_word_index(path='imdb_word_index.json'): """Retrieves the dictionary mapping word indices back to words. diff --git a/tensorflow/python/keras/_impl/keras/datasets/mnist.py b/tensorflow/python/keras/_impl/keras/datasets/mnist.py index e98f29537f4e29c649d0a1879e75505b050d6639..e9f53480150034d3e83f85cfad67f63e61422f3e 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/mnist.py +++ b/tensorflow/python/keras/_impl/keras/datasets/mnist.py @@ -21,8 +21,10 @@ from __future__ import print_function import numpy as np from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.mnist.load_data') def load_data(path='mnist.npz'): """Loads the MNIST dataset. diff --git a/tensorflow/python/keras/_impl/keras/datasets/reuters.py b/tensorflow/python/keras/_impl/keras/datasets/reuters.py index 3fed12b59fc2102fb5d3d30837772f594189082f..6da5aa4b5eb8b8eb5dcd8c75c3f1f86340436601 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/reuters.py +++ b/tensorflow/python/keras/_impl/keras/datasets/reuters.py @@ -25,8 +25,10 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.reuters.load_data') def load_data(path='reuters.npz', num_words=None, skip_top=0, @@ -123,6 +125,7 @@ def load_data(path='reuters.npz', return (x_train, y_train), (x_test, y_test) +@tf_export('keras.datasets.reuters.get_word_index') def get_word_index(path='reuters_word_index.json'): """Retrieves the dictionary mapping word indices back to words. diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py index 22496e8a765d4e86e7ef7ac5a25e6f4af94a28ce..f0f5e1fb463b828afe0f5bc19369408c98b57a08 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py @@ -563,7 +563,7 @@ class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer): return dict(list(base_config.items()) + list(config.items())) -class Conv3DTranspose(tf_convolutional_layers.Conv3D, Layer): +class Conv3DTranspose(tf_convolutional_layers.Conv3DTranspose, Layer): """Transposed convolution layer (sometimes called Deconvolution). The need for transposed convolutions generally arises diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 6e38cf2f4181f36fdd0dfeadf699f445774459e3..9ea21c9c363455d693cc4d766b5f94ade56838d9 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -790,7 +790,8 @@ class SimpleRNNCell(Layer): units: Positive integer, dimensionality of the output space. activation: Activation function to use (see [activations](../activations.md)). - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, @@ -946,7 +947,8 @@ class SimpleRNN(RNN): units: Positive integer, dimensionality of the output space. activation: Activation function to use (see [activations](../activations.md)). - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, @@ -1155,11 +1157,15 @@ class GRUCell(Layer): units: Positive integer, dimensionality of the output space. activation: Activation function to use (see [activations](../activations.md)). - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step (see [activations](../activations.md)). + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. @@ -1392,11 +1398,15 @@ class GRU(RNN): units: Positive integer, dimensionality of the output space. activation: Activation function to use (see [activations](../activations.md)). - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step (see [activations](../activations.md)). + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. @@ -1630,11 +1640,15 @@ class LSTMCell(Layer): units: Positive integer, dimensionality of the output space. activation: Activation function to use (see [activations](../activations.md)). - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step (see [activations](../activations.md)). + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. @@ -1896,11 +1910,16 @@ class LSTM(RNN): units: Positive integer, dimensionality of the output space. activation: Activation function to use (see [activations](../activations.md)). - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step (see [activations](../activations.md)). + Default: hyperbolic tangent (`tanh`). + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py index d0be29f8298fbc83ac518bb7ddf5eda312119e96..d9e8f37e36cff0723c02820e16cc502bb0aea294 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py @@ -38,6 +38,7 @@ from six.moves.urllib.error import URLError from six.moves.urllib.request import urlopen from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.util.tf_export import tf_export try: import queue # pylint:disable=g-import-not-at-top @@ -135,6 +136,7 @@ def _extract_archive(file_path, path='.', archive_format='auto'): return False +@tf_export('keras.utils.get_file') def get_file(fname, origin, untar=False, @@ -315,6 +317,7 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535): return False +@tf_export('keras.utils.Sequence') class Sequence(object): """Base object for fitting to a sequence of data, such as a dataset. @@ -402,6 +405,7 @@ def get_index(uid, i): return _SHARED_SEQUENCES[uid][i] +@tf_export('keras.utils.SequenceEnqueuer') class SequenceEnqueuer(object): """Base class to enqueue inputs. @@ -608,6 +612,7 @@ class OrderedEnqueuer(SequenceEnqueuer): self.executor.join() +@tf_export('keras.utils.GeneratorEnqueuer') class GeneratorEnqueuer(SequenceEnqueuer): """Builds a queue out of a data generator. @@ -752,4 +757,3 @@ class GeneratorEnqueuer(SequenceEnqueuer): success, value = self.queue.get() if not success: six.reraise(value.__class__, value, value.__traceback__) - diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py index e9e54c2a2a713423b77e8279740f0338263206eb..a805315c94628f263dd4ce7a8b0f751cdf685ca0 100644 --- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py @@ -29,10 +29,12 @@ import six from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export _GLOBAL_CUSTOM_OBJECTS = {} +@tf_export('keras.utils.CustomObjectScope') class CustomObjectScope(object): """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. @@ -68,6 +70,7 @@ class CustomObjectScope(object): _GLOBAL_CUSTOM_OBJECTS.update(self.backup) +@tf_export('keras.utils.custom_object_scope') def custom_object_scope(*args): """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. @@ -98,6 +101,7 @@ def custom_object_scope(*args): return CustomObjectScope(*args) +@tf_export('keras.utils.get_custom_objects') def get_custom_objects(): """Retrieves a live reference to the global dictionary of custom objects. @@ -118,6 +122,7 @@ def get_custom_objects(): return _GLOBAL_CUSTOM_OBJECTS +@tf_export('keras.utils.serialize_keras_object') def serialize_keras_object(instance): _, instance = tf_decorator.unwrap(instance) if instance is None: @@ -133,6 +138,7 @@ def serialize_keras_object(instance): raise ValueError('Cannot serialize', instance) +@tf_export('keras.utils.deserialize_keras_object') def deserialize_keras_object(identifier, module_objects=None, custom_objects=None, @@ -275,6 +281,7 @@ def has_arg(fn, name, accept_all=False): return name in arg_spec.args +@tf_export('keras.utils.Progbar') class Progbar(object): """Displays a progress bar. diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py index a8fc18c17aee58fa406c3057cc98844d9687a9ba..e123339f5a7cc629778e2247d985dbe4591da54a 100644 --- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py @@ -21,6 +21,7 @@ from collections import defaultdict import sys import numpy as np +from tensorflow.python.util.tf_export import tf_export try: @@ -29,6 +30,7 @@ except ImportError: h5py = None +@tf_export('keras.utils.HDF5Matrix') class HDF5Matrix(object): """Representation of HDF5 dataset to be used instead of a Numpy array. diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py index 053c0600a33d6ab0151ecc8879cbc68fe731dbe5..30af285cbfb8b8bc38e62d20f0698f9d3c121d10 100644 --- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel +from tensorflow.python.util.tf_export import tf_export def count_params(weights): @@ -187,6 +188,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): print_fn('_' * line_length) +@tf_export('keras.utils.convert_all_kernels_in_model') def convert_all_kernels_in_model(model): """Converts all convolution kernels in a model from Theano to TensorFlow. diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py index 67d83bf42c4387be6e5ba578663ecf02ade054c8..3dddb99191c8a40adf8f39216679a0975d4e830c 100644 --- a/tensorflow/python/keras/_impl/keras/utils/np_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/np_utils.py @@ -18,8 +18,10 @@ from __future__ import division from __future__ import print_function import numpy as np +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.utils.to_categorical') def to_categorical(y, num_classes=None): """Converts a class vector (integers) to binary class matrix. @@ -48,6 +50,7 @@ def to_categorical(y, num_classes=None): return categorical +@tf_export('keras.utils.normalize') def normalize(x, axis=-1, order=2): """Normalizes a Numpy array. diff --git a/tensorflow/python/keras/_impl/keras/utils/training_utils.py b/tensorflow/python/keras/_impl/keras/utils/training_utils.py index 0bf4ac8a24d3011e05f2db101cd02931e0b65849..ce7402e9d279278eaaf5aab58a3973eec6de8e99 100644 --- a/tensorflow/python/keras/_impl/keras/utils/training_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/training_utils.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import ops from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine.training import Model from tensorflow.python.ops import array_ops +from tensorflow.python.util.tf_export import tf_export def _get_available_devices(): @@ -32,6 +33,7 @@ def _normalize_device_name(name): return name +@tf_export('keras.utils.multi_gpu_model') def multi_gpu_model(model, gpus): """Replicates a model on different GPUs. @@ -203,4 +205,3 @@ def multi_gpu_model(model, gpus): for name, outputs in zip(model.output_names, all_outputs): merged.append(concatenate(outputs, axis=0, name=name)) return Model(model.inputs, merged) - diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py index d56c4484ce35d0c6af08d6199867b7845f367c88..1ec8e3a2bf6d539655b4417cbd413a926978cee2 100644 --- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import sys +from tensorflow.python.util.tf_export import tf_export try: # pydot-ng is a fork of pydot that is better maintained. @@ -128,6 +129,7 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): return dot +@tf_export('keras.utils.plot_model') def plot_model(model, to_file='model.png', show_shapes=False, diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index de6aba4477fee84c01da6d684418f1101733ce39..8c1d16c2a8fc2ed1130d81c46aa233bf8416caf8 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -87,6 +87,8 @@ cuda_py_test( srcs = ["list_ops_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:list_ops", "//tensorflow/python/eager:context", "//tensorflow/python:framework_for_generated_wrappers", @@ -2488,6 +2490,7 @@ cuda_py_test( "//tensorflow/python:sparse_ops", ], shard_count = 5, + tags = ["noasan"], ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 1dbe7deb97c6e4305dbf998813249bd80ace3363..ec6184aacdb1ee6376944114ace3f1c1c1407aa9 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -975,6 +975,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase): self.assertEqual(2, array_ops.rank(sp).eval()) +@test_util.with_c_api class SequenceMaskTest(test_util.TensorFlowTestCase): def testExceptions(self): @@ -993,7 +994,10 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): # test dtype and default maxlen: res = array_ops.sequence_mask( constant_op.constant([0, 1, 4]), dtype=dtypes.float32) - self.assertAllEqual(res.get_shape().as_list(), [3, None]) + if ops._USE_C_API: + self.assertAllEqual(res.get_shape().as_list(), [3, 4]) + else: + self.assertAllEqual(res.get_shape().as_list(), [3, None]) self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]) @@ -1009,7 +1013,10 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): # test dtype and default maxlen: res = array_ops.sequence_mask( constant_op.constant([[0, 1, 4], [1, 2, 3]]), dtype=dtypes.float32) - self.assertAllEqual(res.get_shape().as_list(), [2, 3, None]) + if ops._USE_C_API: + self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4]) + else: + self.assertAllEqual(res.get_shape().as_list(), [2, 3, None]) self.assertAllEqual(res.eval(), [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]], diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 030c690167fd7edef9ad929eb5cee5f03d9d5883..576bb68ba49cf5a5c7618131ad8a567672cb08d8 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -454,18 +454,20 @@ class ZerosLikeTest(test.TestCase): def testZerosLikeCPU(self): for dtype in [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int8, - dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16, dtypes_lib.int32, - dtypes_lib.int64, dtypes_lib.bool, dtypes_lib.complex64, - dtypes_lib.complex128, dtypes_lib.string + dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64, + dtypes_lib.int8, dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16, + dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.bool, + dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.string ]: self._compareZeros(dtype, fully_defined_shape=False, use_gpu=False) self._compareZeros(dtype, fully_defined_shape=True, use_gpu=False) def testZerosLikeGPU(self): for dtype in [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, - dtypes_lib.bool, dtypes_lib.int64, dtypes_lib.string + dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64, + dtypes_lib.int32, dtypes_lib.int64, + dtypes_lib.complex64, dtypes_lib.complex128, + dtypes_lib.bool ]: self._compareZeros(dtype, fully_defined_shape=False, use_gpu=True) self._compareZeros(dtype, fully_defined_shape=True, use_gpu=True) diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py index d92797a7d38cbe359d8166ea9ad7c25bd9cd1f4b..e2e6205911caa06b52f21658a91a53d60a0130ff 100644 --- a/tensorflow/python/kernel_tests/conv1d_test.py +++ b/tensorflow/python/kernel_tests/conv1d_test.py @@ -30,27 +30,29 @@ from tensorflow.python.platform import test class Conv1DTest(test.TestCase): def testBasic(self): - """Test that argument passing to conv2d is handled properly.""" - - x = constant_op.constant([1, 2, 3, 4], dtype=dtypes.float32) - x = array_ops.expand_dims(x, 0) # Add batch dimension - x = array_ops.expand_dims(x, 2) # And depth dimension - filters = constant_op.constant([2, 1], dtype=dtypes.float32) - filters = array_ops.expand_dims(filters, 1) # in_channels - filters = array_ops.expand_dims(filters, 2) # out_channels - # Filters is 2x1x1 - for stride in [1, 2]: - with self.test_session(use_gpu=test.is_gpu_available()): - c = nn_ops.conv1d(x, filters, stride, padding="VALID") - reduced = array_ops.squeeze(c) - output = reduced.eval() - if stride == 1: - self.assertEqual(len(output), 3) - self.assertAllClose(output, - [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]) - else: - self.assertEqual(len(output), 2) - self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]) + """Test that argument passing to conv1d is handled properly.""" + # TODO(yongtang): dtypes.float64 can only be enabled once conv2d support + # dtypes.float64, as conv1d implicitly calls conv2d after expand_dims. + for dtype in [dtypes.float16, dtypes.float32]: + x = constant_op.constant([1, 2, 3, 4], dtype=dtype) + x = array_ops.expand_dims(x, 0) # Add batch dimension + x = array_ops.expand_dims(x, 2) # And depth dimension + filters = constant_op.constant([2, 1], dtype=dtype) + filters = array_ops.expand_dims(filters, 1) # in_channels + filters = array_ops.expand_dims(filters, 2) # out_channels + # Filters is 2x1x1 + for stride in [1, 2]: + with self.test_session(use_gpu=test.is_gpu_available()): + c = nn_ops.conv1d(x, filters, stride, padding="VALID") + reduced = array_ops.squeeze(c) + output = reduced.eval() + if stride == 1: + self.assertEqual(len(output), 3) + self.assertAllClose(output, + [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]) + else: + self.assertEqual(len(output), 2) + self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]) def testConv1DTranspose(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index cea12ea8ecfa7a4f592454a96f7f3dc9dd3663ed..a91917b27faf46710d3f494b76929f4c7b9e9eec 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops @@ -1168,6 +1169,32 @@ class BinaryOpTest(test.TestCase): self._compareCpu(x1, x2, np.arctan2, math_ops.atan2) self._compareGpu(x1, x2, np.arctan2, math_ops.atan2) + def testPowNegativeExponent(self): + for dtype in [np.int32, np.int64]: + with self.test_session(use_gpu=False) as sess: + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + "Integers to negative integer powers are not allowed"): + x = np.array([5, 2]).astype(dtype) + y = np.array([-2, 3]).astype(dtype) + sess.run(math_ops.pow(x, y)) + + with self.test_session(use_gpu=False) as sess: + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + "Integers to negative integer powers are not allowed"): + x = np.array([5, 2]).astype(dtype) + y = np.array([2, -3]).astype(dtype) + sess.run(math_ops.pow(x, y)) + + with self.test_session(use_gpu=False) as sess: + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + "Integers to negative integer powers are not allowed"): + x = np.array([5, 2]).astype(dtype) + y = -3 + sess.run(math_ops.pow(x, y)) + class ComparisonOpTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py index 019c1bc353a9891da6967a7ce9114b58226a980a..ca2358fe99934e110ba743c6085d1f25ff0f5e5e 100644 --- a/tensorflow/python/kernel_tests/distributions/categorical_test.py +++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py @@ -100,6 +100,10 @@ class CategoricalTest(test.TestCase): self.assertEqual( dist.logits.dtype, dist.log_prob(np.array( 0, dtype=np.int64)).dtype) + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + dist = make_categorical([], 5, dtype=dtype) + self.assertEqual(dist.dtype, dtype) + self.assertEqual(dist.dtype, dist.sample(5).dtype) def testUnknownShape(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 8fae044e2e1e8a92db898c97b4e7824564747f69..1577b7bc8021a326eb720bdf059b8d1c568c0cc1 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -82,6 +83,21 @@ class ListOpsTest(test_util.TensorFlowTestCase): with context.device("gpu:0"): self.testTensorListFromTensor() + def testGetSetItem(self): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 1.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 2.0]) + + def testGetSetGPU(self): + if not context.num_gpus(): + return + with context.device("gpu:0"): + self.testGetSetItem() + def testUnknownShape(self): l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=-1) @@ -159,6 +175,27 @@ class ListOpsTest(test_util.TensorFlowTestCase): result = c2 * 2.0 self.assertAllEqual(tape.gradient(result, [c])[0], [2.0, 2.0]) + def testGetSetGradients(self): + with backprop.GradientTape() as tape: + c = constant_op.constant([1.0, 2.0]) + tape.watch(c) + l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) + c2 = constant_op.constant(3.0) + tape.watch(c2) + l = list_ops.tensor_list_set_item(l, 0, c2) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + ee = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32) + y = e * e + ee * ee + grad_c, grad_c2 = tape.gradient(y, [c, c2]) + self.assertAllEqual(grad_c, [0.0, 4.0]) + self.assertAllEqual(grad_c2, 6.0) + + def testSetOutOfBounds(self): + c = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) + with self.assertRaises(errors.InvalidArgumentError): + list_ops.tensor_list_set_item(l, 20, 3.0) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index 3358b78efd22f86b455041d72e6ff663f74acdd8..e0e752147cdf8690d22fa782aca2561b2935fa8e 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -3628,7 +3628,8 @@ class MeanPerClassAccuracyTest(test.TestCase): predictions=array_ops.ones([10, 1]), labels=array_ops.ones([10, 1]), num_classes=2) - _assert_metric_variables(self, ('mean_accuracy/total_confusion_matrix:0',)) + _assert_metric_variables(self, ('mean_accuracy/count:0', + 'mean_accuracy/total:0')) def testMetricsCollections(self): my_collection_name = '__metrics__' @@ -3797,23 +3798,6 @@ class MeanPerClassAccuracyTest(test.TestCase): desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.]) self.assertAlmostEqual(desired_output, mean_accuracy.eval()) - def testUpdateOpEvalIsAccumulatedConfusionMatrix(self): - predictions = array_ops.concat([ - constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5]) - ], 0) - labels = array_ops.concat([ - constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7]) - ], 0) - num_classes = 2 - with self.test_session() as sess: - mean_accuracy, update_op = metrics.mean_per_class_accuracy( - labels, predictions, num_classes) - sess.run(variables.local_variables_initializer()) - confusion_matrix = update_op.eval() - self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix) - desired_mean_accuracy = np.mean([3. / 3., 5. / 7.]) - self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval()) - def testAllCorrect(self): predictions = array_ops.zeros([40]) labels = array_ops.zeros([40]) @@ -3822,7 +3806,7 @@ class MeanPerClassAccuracyTest(test.TestCase): mean_accuracy, update_op = metrics.mean_per_class_accuracy( labels, predictions, num_classes) sess.run(variables.local_variables_initializer()) - self.assertEqual(40, update_op.eval()[0]) + self.assertEqual(1.0, update_op.eval()[0]) self.assertEqual(1.0, mean_accuracy.eval()) def testAllWrong(self): @@ -3833,7 +3817,7 @@ class MeanPerClassAccuracyTest(test.TestCase): mean_accuracy, update_op = metrics.mean_per_class_accuracy( labels, predictions, num_classes) sess.run(variables.local_variables_initializer()) - self.assertAllEqual([[0, 0], [40, 0]], update_op.eval()) + self.assertAllEqual([0.0, 0.0], update_op.eval()) self.assertEqual(0., mean_accuracy.eval()) def testResultsWithSomeMissing(self): @@ -3852,8 +3836,9 @@ class MeanPerClassAccuracyTest(test.TestCase): mean_accuracy, update_op = metrics.mean_per_class_accuracy( labels, predictions, num_classes, weights=weights) sess.run(variables.local_variables_initializer()) - self.assertAllEqual([[2, 0], [2, 4]], update_op.eval()) - desired_mean_accuracy = np.mean([2. / 2., 4. / 6.]) + desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32) + self.assertAllEqual(desired_accuracy, update_op.eval()) + desired_mean_accuracy = np.mean(desired_accuracy) self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval()) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 7b131a5b8ca46cc205ec29d5a48cd704b1c67b04..b4b555591d054226210eb6af20036967b240928f 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test +@test_util.with_c_api class ResourceVariableOpsTest(test_util.TensorFlowTestCase): def tearDown(self): @@ -342,14 +343,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable( 2.0, caching_device="/job:localhost") self.assertEqual("/job:localhost", v.value().device) - with self.assertRaisesRegexp(ValueError, "No attr named '_class'"): + with self.assertRaises(ValueError): _ = v.value().op.get_attr("_class") with ops.colocate_with(v.op): w = resource_variable_ops.ResourceVariable( 2.0, caching_device="/job:localhost") self.assertEqual("/job:localhost", w.value().device) - with self.assertRaisesRegexp(ValueError, "No attr named '_class'"): + with self.assertRaises(ValueError): _ = w.value().op.get_attr("_class") def testSharedName(self): diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index f375157287460daff42670db4e30a06b6e75d177..38205518b528b44313b1de83d06707b4498f061d 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -64,7 +64,7 @@ class TensordotTest(test_lib.TestCase): a = [[1, 2], [3, 4]] b = [[1, 2], [3, 4]] # Invalid static axes. - for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]: + for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]: with self.assertRaises(ValueError): math_ops.tensordot(a, b, axes_value) @@ -87,7 +87,7 @@ class TensordotTest(test_lib.TestCase): # Test case for 11950 def test_valid_axis(self): - for axes_value in [1, 2], [[1], [2]]: + for axes_value in [1, 2], [[1], [2]], [[], []], 0: with self.test_session() as sess: np_a = np.ones((3,3)) np_b = np.array([2, 3, 1])[None, None] @@ -102,29 +102,29 @@ class TensordotTest(test_lib.TestCase): def test_partial_shape_inference(self): - a = array_ops.placeholder(dtypes.float32) - b = array_ops.placeholder(dtypes.float32) - axes = ([1], [0]) - output = math_ops.tensordot(a, b, axes) - self.assertEqual(output.get_shape().ndims, None) - a.set_shape([None, 2]) - b.set_shape([2, 3]) - output = math_ops.tensordot(a, b, axes) - output_shape = output.get_shape() - self.assertEqual(output_shape.ndims, 2) - output_shape = output_shape.as_list() - self.assertEqual(output_shape[0], None) - self.assertEqual(output_shape[1], 3) - a = array_ops.placeholder(dtypes.float32) - b = array_ops.placeholder(dtypes.float32) - a.set_shape([2, 2]) - b.set_shape([2, None]) - output = math_ops.tensordot(a, b, axes) - output_shape = output.get_shape() - self.assertEqual(output_shape.ndims, 2) - output_shape = output_shape.as_list() - self.assertEqual(output_shape[0], 2) - self.assertEqual(output_shape[1], None) + for axes in ([1],[0]), 1: + a = array_ops.placeholder(dtypes.float32) + b = array_ops.placeholder(dtypes.float32) + output = math_ops.tensordot(a, b, axes) + self.assertEqual(output.get_shape().ndims, None) + a.set_shape([None, 2]) + b.set_shape([2, 3]) + output = math_ops.tensordot(a, b, axes) + output_shape = output.get_shape() + self.assertEqual(output_shape.ndims, 2) + output_shape = output_shape.as_list() + self.assertEqual(output_shape[0], None) + self.assertEqual(output_shape[1], 3) + a = array_ops.placeholder(dtypes.float32) + b = array_ops.placeholder(dtypes.float32) + a.set_shape([2, 2]) + b.set_shape([2, None]) + output = math_ops.tensordot(a, b, axes) + output_shape = output.get_shape() + self.assertEqual(output_shape.ndims, 2) + output_shape = output_shape.as_list() + self.assertEqual(output_shape[0], 2) + self.assertEqual(output_shape[1], None) def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): @@ -191,8 +191,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) b_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) - all_axes = [1] - if a_np.ndim > 1: + all_axes = [0, 1] + if a_np.ndim > 2: all_axes.append(a_np.ndim - 1) for axes in all_axes: np_ans = np.tensordot(a_np, b_np, axes=axes) diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index f1a86625e080274e33abecc9db6e0c9957010d01..8527f116f9541942e52ba2ab635ca1212ea38583 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -131,6 +131,30 @@ class VariableScopeTest(test.TestCase): self.assertFalse(v in store.non_trainable_variables()) self.assertTrue(w in store.non_trainable_variables()) + # Test copying. + new_store = store.copy() + with new_store.as_default(): + new_v = variable_scope.get_variable("v") + new_w = variable_scope.get_variable("w") + self.assertEqual(new_v.numpy(), v.numpy()) + self.assertEqual(new_w.numpy(), w.numpy()) + self.assertTrue(new_v in new_store.variables()) + self.assertTrue(new_w in new_store.variables()) + self.assertTrue(new_v in new_store.trainable_variables()) + self.assertFalse(new_w in new_store.trainable_variables()) + self.assertFalse(new_v in new_store.non_trainable_variables()) + self.assertTrue(new_w in new_store.non_trainable_variables()) + + # Check that variables are separate instances. + for v in store.variables(): + v.assign(-1) + for v in new_store.variables(): + v.assign(1) + for v in store.variables(): + self.assertEqual(v.numpy(), -1) + for v in new_store.variables(): + self.assertEqual(v.numpy(), 1) + @test_util.run_in_graph_and_eager_modes() def testInitFromNonTensorValue(self): v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) @@ -1253,6 +1277,24 @@ class VariableScopeWithCustomGetterTest(test.TestCase): (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3])) + ((np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7])))) + def testVariableCreator(self): + + variable_names = [] + + def creator_a(next_creator, **kwargs): + variable_names.append(kwargs.get("name", "")) + return next_creator(**kwargs) + + def creator_b(next_creator, **kwargs): + kwargs["name"] = "forced_name" + return next_creator(**kwargs) + + with variable_scope.variable_creator_scope(creator_a): + with variable_scope.variable_creator_scope(creator_b): + variable_scope.variable(1.0, name="one_name") + + self.assertAllEqual(variable_names, ["forced_name"]) + class PartitionInfoTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index 43be08f8a1436eebdd712a4bbb69ce8ae8d12827..c6c7c4e26cb5e4eff22d1bb9d3e32c227c1c838f 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -240,6 +240,16 @@ class XentTest(test.TestCase): self._testXentWrapper(features, labels, dim=-1, use_gpu=False) self._testXentWrapper(features, labels, dim=-1, use_gpu=True) + def testZeroDimension(self): + features = np.zeros([0, 2, 4]).astype(np.float32) + labels = np.zeros([0, 2, 4]).astype(np.float32) + np_loss, _ = self._npXent(features, labels) + with self.test_session(use_gpu=True) as sess: + loss = nn_ops.softmax_cross_entropy_with_logits( + labels=labels, logits=features) + tf_loss = sess.run(loss) + self.assertAllEqual(np_loss, tf_loss) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index ab1fa551e1171db60cbb3b080f453036862c895c..e8dba3cea321a415b84e1ec89fd7b021e2b272d0 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -819,8 +819,8 @@ def conv3d(inputs, return layer.apply(inputs) -class SeparableConv2D(Conv2D): - """Depthwise separable 2D convolution. +class _SeparableConv(_Conv): + """Abstract base layer for separable nD convolution. This layer performs a depthwise convolution that acts separately on channels, followed by a pointwise convolution that mixes channels. @@ -829,12 +829,13 @@ class SeparableConv2D(Conv2D): It then optionally applies an activation function to produce the final output. Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). - kernel_size: A tuple or list of 2 integers specifying the spatial + kernel_size: A tuple or list of integers specifying the spatial dimensions of the filters. Can be a single integer to specify the same value for all spatial dimensions. - strides: A tuple or list of 2 positive integers specifying the strides + strides: A tuple or list of integers specifying the strides of the convolution. Can be a single integer to specify the same value for all spatial dimensions. Specifying any `stride` value != 1 is incompatible with specifying @@ -843,9 +844,8 @@ class SeparableConv2D(Conv2D): data_format: A string, one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. dilation_rate: An integer or tuple/list of 2 integers, specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for @@ -883,12 +883,14 @@ class SeparableConv2D(Conv2D): name: A string, the name of the layer. """ - def __init__(self, filters, + def __init__(self, + rank, + filters, kernel_size, - strides=(1, 1), + strides=1, padding='valid', data_format='channels_last', - dilation_rate=(1, 1), + dilation_rate=1, depth_multiplier=1, activation=None, use_bias=True, @@ -905,7 +907,8 @@ class SeparableConv2D(Conv2D): trainable=True, name=None, **kwargs): - super(SeparableConv2D, self).__init__( + super(_SeparableConv, self).__init__( + rank=rank, filters=filters, kernel_size=kernel_size, strides=strides, @@ -920,7 +923,6 @@ class SeparableConv2D(Conv2D): trainable=trainable, name=name, **kwargs) - self.data_format = data_format self.depth_multiplier = depth_multiplier self.depthwise_initializer = depthwise_initializer self.pointwise_initializer = pointwise_initializer @@ -930,26 +932,21 @@ class SeparableConv2D(Conv2D): self.pointwise_constraint = pointwise_constraint def build(self, input_shape): - if len(input_shape) < 4: - raise ValueError('Inputs to `SeparableConv2D` should have rank 4. ' - 'Received input shape:', str(input_shape)) + input_shape = tensor_shape.TensorShape(input_shape) if self.data_format == 'channels_first': channel_axis = 1 else: - channel_axis = 3 - if input_shape[channel_axis] is None: - raise ValueError('The channel dimension of the inputs to ' - '`SeparableConv2D` ' + channel_axis = -1 + if input_shape[channel_axis].value is None: + raise ValueError('The channel dimension of the inputs ' 'should be defined. Found `None`.') - input_dim = int(input_shape[channel_axis]) - self.input_spec = base.InputSpec(ndim=4, axes={channel_axis: input_dim}) - depthwise_kernel_shape = (self.kernel_size[0], - self.kernel_size[1], - input_dim, - self.depth_multiplier) - pointwise_kernel_shape = (1, 1, - self.depth_multiplier * input_dim, - self.filters) + input_dim = input_shape[channel_axis].value + self.input_spec = base.InputSpec(ndim=self.rank + 2, + axes={channel_axis: input_dim}) + depthwise_kernel_shape = self.kernel_size + (input_dim, + self.depth_multiplier) + pointwise_kernel_shape = ( + 1,) * self.rank + (self.depth_multiplier * input_dim, self.filters) self.depthwise_kernel = self.add_variable( name='depthwise_kernel', @@ -979,6 +976,264 @@ class SeparableConv2D(Conv2D): self.bias = None self.built = True + def call(self, inputs): + raise NotImplementedError + + +class SeparableConv1D(_SeparableConv): + """Depthwise separable 1D convolution. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: A single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer=None, + pointwise_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(SeparableConv1D, self).__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) + + def call(self, inputs): + if self.data_format == 'channels_last': + strides = (1, 1) + self.strides + (1,) + spatial_start_dim = 1 + else: + strides = (1, 1, 1) + self.strides + spatial_start_dim = 2 + + # Explicitly broadcast inputs and kernels to 4D. + # TODO(fchollet): refactor when a native separable_conv1d op is available. + inputs = array_ops.expand_dims(inputs, spatial_start_dim) + depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0) + pointwise_kernel = array_ops.expand_dims(self.pointwise_kernel, 0) + dilation_rate = (1,) + self.dilation_rate + + outputs = nn.separable_conv2d( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=strides, + padding=self.padding.upper(), + rate=dilation_rate, + data_format=utils.convert_data_format(self.data_format, ndim=4)) + + if self.use_bias: + outputs = nn.bias_add( + outputs, + self.bias, + data_format=utils.convert_data_format(self.data_format, ndim=4)) + + outputs = array_ops.squeeze(outputs, [spatial_start_dim]) + + if self.activation is not None: + return self.activation(outputs) + return outputs + + +class SeparableConv2D(_SeparableConv): + """Depthwise separable 2D convolution. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A tuple or list of 2 integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + strides: A tuple or list of 2 positive integers specifying the strides + of the convolution. Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer=None, + pointwise_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(SeparableConv2D, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) + def call(self, inputs): # Apply the actual ops. if self.data_format == 'channels_last': @@ -1004,25 +1259,121 @@ class SeparableConv2D(Conv2D): return self.activation(outputs) return outputs - def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - if self.data_format == 'channels_first': - rows = input_shape[2] - cols = input_shape[3] - else: - rows = input_shape[1] - cols = input_shape[2] - rows = utils.conv_output_length(rows, self.kernel_size[0], - self.padding, self.strides[0]) - cols = utils.conv_output_length(cols, self.kernel_size[1], - self.padding, self.strides[1]) - if self.data_format == 'channels_first': - return tensor_shape.TensorShape( - [input_shape[0], self.filters, rows, cols]) - else: - return tensor_shape.TensorShape( - [input_shape[0], rows, cols, self.filters]) +def separable_conv1d(inputs, + filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer=None, + pointwise_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for the depthwise separable 1D convolution layer. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + inputs: Input tensor. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: A single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = SeparableConv1D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) def separable_conv2d(inputs, @@ -1553,6 +1904,7 @@ class Conv3DTranspose(Conv3D): dtype=self.dtype) else: self.bias = None + self.built = True def call(self, inputs): inputs_shape = array_ops.shape(inputs) @@ -1623,6 +1975,8 @@ class Conv3DTranspose(Conv3D): if self.use_bias: outputs_shape = outputs.shape.as_list() + if outputs_shape[0] is None: + outputs_shape[0] = -1 if self.data_format == 'channels_first': outputs_4d = array_ops.reshape(outputs, [ outputs_shape[0], outputs_shape[1], @@ -1656,11 +2010,11 @@ class Conv3DTranspose(Conv3D): output_shape[c_axis] = self.filters output_shape[d_axis] = utils.deconv_output_length( - output_shape[d_axis], stride_d, kernel_d, self.padding) + output_shape[d_axis], kernel_d, self.padding, stride_d) output_shape[h_axis] = utils.deconv_output_length( - output_shape[h_axis], stride_h, kernel_h, self.padding) + output_shape[h_axis], kernel_h, self.padding, stride_h) output_shape[w_axis] = utils.deconv_output_length( - output_shape[w_axis], stride_w, kernel_w, self.padding) + output_shape[w_axis], kernel_w, self.padding, stride_w) return tensor_shape.TensorShape(output_shape) diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py index e41eb5c32ff8ee825c0bd900efd58166017004d5..160e732b6798697d05815e13a7b1c399070f0783 100644 --- a/tensorflow/python/layers/convolutional_test.py +++ b/tensorflow/python/layers/convolutional_test.py @@ -326,6 +326,168 @@ class ConvTest(test.TestCase): self.assertEqual(conv3d.bias_constraint, b_constraint) +@test_util.with_c_api +class SeparableConv1DTest(test.TestCase): + + def testInvalidDataFormat(self): + length = 9 + data = random_ops.random_uniform((5, length, 3), seed=1) + with self.assertRaisesRegexp(ValueError, 'data_format'): + conv_layers.separable_conv1d(data, 32, 3, data_format='invalid') + + def testInvalidStrides(self): + length = 9 + data = random_ops.random_uniform((5, length, 3), seed=1) + with self.assertRaisesRegexp(ValueError, 'strides'): + conv_layers.separable_conv1d(data, 32, 3, strides=(1, 2)) + + with self.assertRaisesRegexp(ValueError, 'strides'): + conv_layers.separable_conv1d(data, 32, 3, strides=None) + + def testInvalidKernelSize(self): + length = 9 + data = random_ops.random_uniform((5, length, 3), seed=1) + with self.assertRaisesRegexp(ValueError, 'kernel_size'): + conv_layers.separable_conv1d(data, 32, (1, 2)) + + with self.assertRaisesRegexp(ValueError, 'kernel_size'): + conv_layers.separable_conv1d(data, 32, None) + + def testCreateSeparableConv1D(self): + length = 9 + data = random_ops.random_uniform((5, length, 4)) + layer = conv_layers.SeparableConv1D(32, 3, activation=nn_ops.relu) + output = layer.apply(data) + self.assertEqual(output.op.name, 'separable_conv1d/Relu') + self.assertEqual(output.get_shape().as_list(), [5, length - 2, 32]) + self.assertEqual(layer.depthwise_kernel.get_shape().as_list(), [3, 4, 1]) + self.assertEqual(layer.pointwise_kernel.get_shape().as_list(), [1, 4, 32]) + self.assertEqual(layer.bias.get_shape().as_list(), [32]) + + def testCreateSeparableConv1DDepthMultiplier(self): + length = 9 + data = random_ops.random_uniform((5, length, 4)) + layer = conv_layers.SeparableConv1D(32, 3, depth_multiplier=2) + output = layer.apply(data) + self.assertEqual(output.get_shape().as_list(), [5, length - 2, 32]) + self.assertEqual(layer.depthwise_kernel.get_shape().as_list(), [3, 4, 2]) + self.assertEqual(layer.pointwise_kernel.get_shape().as_list(), [1, 8, 32]) + self.assertEqual(layer.bias.get_shape().as_list(), [32]) + + def testCreateSeparableConv1DChannelsFirst(self): + length = 9 + data = random_ops.random_uniform((5, 4, length)) + layer = conv_layers.SeparableConv1D(32, 3, data_format='channels_first') + output = layer.apply(data) + self.assertEqual(output.get_shape().as_list(), [5, 32, length - 2]) + self.assertEqual(layer.depthwise_kernel.get_shape().as_list(), [3, 4, 1]) + self.assertEqual(layer.pointwise_kernel.get_shape().as_list(), [1, 4, 32]) + self.assertEqual(layer.bias.get_shape().as_list(), [32]) + + def testSeparableConv1DPaddingSame(self): + length = 9 + data = random_ops.random_uniform((5, length, 32), seed=1) + layer = conv_layers.SeparableConv1D( + 64, length, padding='same') + output = layer.apply(data) + self.assertEqual(output.get_shape().as_list(), [5, length, 64]) + + def testCreateSeparableConv1DWithStrides(self): + length = 10 + data = random_ops.random_uniform((5, length, 3), seed=1) + layer = conv_layers.SeparableConv1D(32, 3, strides=2, padding='same') + output = layer.apply(data) + self.assertEqual(output.get_shape().as_list(), [5, length // 2, 32]) + + def testCreateSeparableConv1DWithStridesChannelsFirst(self): + data_format = 'channels_first' + length = 10 + data = random_ops.random_uniform((5, 3, length), seed=1) + layer = conv_layers.SeparableConv1D( + 32, 3, strides=2, padding='same', data_format=data_format) + output = layer.apply(data) + self.assertEqual(output.get_shape().as_list(), [5, 32, length // 2]) + + def testFunctionalConv1DReuse(self): + length = 10 + data = random_ops.random_uniform((5, length, 3), seed=1) + conv_layers.separable_conv1d(data, 32, 3, name='sepconv1') + self.assertEqual(len(variables.trainable_variables()), 3) + conv_layers.separable_conv1d(data, 32, 3, name='sepconv1', reuse=True) + self.assertEqual(len(variables.trainable_variables()), 3) + + def testFunctionalConv1DReuseFromScope(self): + with variable_scope.variable_scope('scope'): + length = 10 + data = random_ops.random_uniform((5, length, 3), seed=1) + conv_layers.separable_conv1d(data, 32, 3, name='sepconv1') + self.assertEqual(len(variables.trainable_variables()), 3) + with variable_scope.variable_scope('scope', reuse=True): + conv_layers.separable_conv1d(data, 32, 3, name='sepconv1') + self.assertEqual(len(variables.trainable_variables()), 3) + + def testFunctionalConv1DNoReuse(self): + length = 10 + data = random_ops.random_uniform((5, length, 3), seed=1) + conv_layers.separable_conv1d(data, 32, 3) + self.assertEqual(len(variables.trainable_variables()), 3) + conv_layers.separable_conv1d(data, 32, 3) + self.assertEqual(len(variables.trainable_variables()), 6) + + def testSeparableConv1DDepthwiseRegularizer(self): + length = 9 + data = random_ops.random_uniform((5, length, 4)) + reg = lambda x: 0.1 * math_ops.reduce_sum(x) + layer = conv_layers.SeparableConv1D(32, 3, depthwise_regularizer=reg) + layer.apply(data) + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertEqual(layer.losses, loss_keys) + + def testSeparableConv1DPointwiseRegularizer(self): + length = 9 + data = random_ops.random_uniform((5, length, 4)) + reg = lambda x: 0.1 * math_ops.reduce_sum(x) + layer = conv_layers.SeparableConv1D(32, 3, pointwise_regularizer=reg) + layer.apply(data) + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertEqual(layer.losses, loss_keys) + + def testSeparableConv1DBiasRegularizer(self): + length = 9 + data = random_ops.random_uniform((5, length, 4)) + reg = lambda x: 0.1 * math_ops.reduce_sum(x) + layer = conv_layers.SeparableConv1D(32, 3, bias_regularizer=reg) + layer.apply(data) + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertEqual(layer.losses, loss_keys) + + def testSeparableConv1DNoBias(self): + length = 9 + data = random_ops.random_uniform((5, length, 4)) + layer = conv_layers.SeparableConv1D( + 32, 3, activation=nn_ops.relu, use_bias=False) + output = layer.apply(data) + self.assertEqual(output.op.name, 'separable_conv1d/Relu') + self.assertEqual(layer.bias, None) + + def testConstraints(self): + d_constraint = lambda x: x / math_ops.reduce_sum(x) + p_constraint = lambda x: x / math_ops.reduce_sum(x) + b_constraint = lambda x: x / math_ops.reduce_max(x) + layer = conv_layers.SeparableConv1D(2, 3, + depthwise_constraint=d_constraint, + pointwise_constraint=p_constraint, + bias_constraint=b_constraint) + inputs = random_ops.random_uniform((5, 3, 5), seed=1) + layer(inputs) + self.assertEqual(layer.depthwise_constraint, d_constraint) + self.assertEqual(layer.pointwise_constraint, p_constraint) + self.assertEqual(layer.bias_constraint, b_constraint) + + @test_util.with_c_api class SeparableConv2DTest(test.TestCase): diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py index 0a52b1e8d9216a2535f5ae99751a4f9e9757031d..1555846efde812b9e31f48315decaf1f86aa4f70 100644 --- a/tensorflow/python/layers/layers.py +++ b/tensorflow/python/layers/layers.py @@ -22,6 +22,7 @@ @@Conv1D @@Conv2D @@Conv3D +@@SeparableConv1D @@SeparableConv2D @@Conv2DTranspose @@Conv3DTranspose @@ -43,6 +44,7 @@ @@conv1d @@conv2d @@conv3d +@@separable_conv1d @@separable_conv2d @@conv2d_transpose @@conv3d_transpose @@ -78,6 +80,7 @@ from tensorflow.python.layers.core import dropout from tensorflow.python.layers.core import flatten # Convolutional layers. +from tensorflow.python.layers.convolutional import SeparableConv1D from tensorflow.python.layers.convolutional import SeparableConv2D from tensorflow.python.layers.convolutional import SeparableConvolution2D from tensorflow.python.layers.convolutional import Conv2DTranspose @@ -91,6 +94,7 @@ from tensorflow.python.layers.convolutional import Convolution2D from tensorflow.python.layers.convolutional import Conv3D from tensorflow.python.layers.convolutional import Convolution3D +from tensorflow.python.layers.convolutional import separable_conv1d from tensorflow.python.layers.convolutional import separable_conv2d from tensorflow.python.layers.convolutional import conv2d_transpose from tensorflow.python.layers.convolutional import conv3d_transpose diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py index c6bd7aae07f55772d96cb60b39cf4ef40d9b3581..ab06a3a40826e7d41c040066fd41c56c1ed84ad2 100644 --- a/tensorflow/python/layers/pooling.py +++ b/tensorflow/python/layers/pooling.py @@ -63,14 +63,18 @@ class _Pooling1D(base.Layer): def call(self, inputs): # There is no TF op for 1D pooling, hence we make the inputs 4D. if self.data_format == 'channels_last': - inputs = array_ops.expand_dims(inputs, 2) - pool_shape = (1,) + self.pool_size + (1, 1) - strides = (1,) + self.strides + (1, 1) - data_format = 'NHWC' - else: + # input is NWC, make it NHWC inputs = array_ops.expand_dims(inputs, 1) + # pool on the W dim pool_shape = (1, 1) + self.pool_size + (1,) strides = (1, 1) + self.strides + (1,) + data_format = 'NHWC' + else: + # input is NCW, make it NCHW + inputs = array_ops.expand_dims(inputs, 2) + # pool on the W dim + pool_shape = (1, 1, 1) + self.pool_size + strides = (1, 1, 1) + self.strides data_format = 'NCHW' outputs = self.pool_function( @@ -81,9 +85,9 @@ class _Pooling1D(base.Layer): data_format=data_format) if self.data_format == 'channels_last': - return array_ops.squeeze(outputs, 2) - else: return array_ops.squeeze(outputs, 1) + else: + return array_ops.squeeze(outputs, 2) def compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/layers/pooling_test.py index 589fee5f7196cc542b39506c5bda580a92647f0d..7533674e5a0cf60f91551cd6333c8d802612e03d 100644 --- a/tensorflow/python/layers/pooling_test.py +++ b/tensorflow/python/layers/pooling_test.py @@ -96,33 +96,41 @@ class PoolingTest(test.TestCase): def testCreateMaxPooling1D(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + channels = 3 + images = random_ops.random_uniform((5, width, channels)) layer = pooling_layers.MaxPooling1D(2, strides=2) output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), + [5, width // 2, channels]) def testCreateAveragePooling1D(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + channels = 3 + images = random_ops.random_uniform((5, width, channels)) layer = pooling_layers.AveragePooling1D(2, strides=2) output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), + [5, width // 2, channels]) def testCreateMaxPooling1DChannelsFirst(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + channels = 3 + images = random_ops.random_uniform((5, channels, width)) layer = pooling_layers.MaxPooling1D( 2, strides=2, data_format='channels_first') output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), + [5, channels, width // 2]) def testCreateAveragePooling1DChannelsFirst(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + channels = 3 + images = random_ops.random_uniform((5, channels, width)) layer = pooling_layers.AveragePooling1D( 2, strides=2, data_format='channels_first') output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), + [5, channels, width // 2]) def testCreateMaxPooling3D(self): depth, height, width = 6, 7, 9 diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 985a11272c8a633d80b0372c0b6c669949e9cba8..09d4b01fa43babdc09f8f255e79bbed539ddc04c 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -25,6 +25,7 @@ import numpy as np # pylint: disable=unused-import,g-bad-import-order from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import dtypes from tensorflow.python.platform import test @@ -160,6 +161,24 @@ class Bfloat16Test(test.TestCase): for w in self.float_values(): self.assertEqual(v != w, bfloat16(v) != bfloat16(w)) + def testNan(self): + a = np.isnan(bfloat16(float("nan"))) + self.assertTrue(a) + np.testing.assert_allclose(np.array([1.0, a]), np.array([1.0, a])) + + a = np.array( + [bfloat16(1.34375), + bfloat16(1.4375), + bfloat16(float("nan"))], + dtype=dtypes.bfloat16.as_numpy_dtype) + b = np.array( + [bfloat16(1.3359375), + bfloat16(1.4375), + bfloat16(float("nan"))], + dtype=dtypes.bfloat16.as_numpy_dtype) + np.testing.assert_allclose( + a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True) + class Bfloat16NumPyTest(test.TestCase): diff --git a/tensorflow/python/lib/core/ndarray_tensor.h b/tensorflow/python/lib/core/ndarray_tensor.h index 5172d504bd47d2f88afb088161d74a575a4213aa..b2cd4133ca65205ee432487e80430222064ef1a4 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.h +++ b/tensorflow/python/lib/core/ndarray_tensor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ +#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ // Must be included first. #include "tensorflow/python/lib/core/numpy.h" @@ -45,4 +45,4 @@ Status TensorToNdarray(const Tensor& t, PyObject** ret); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ +#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index dc56b3948626de7d76895378ade04b14e7d779b1..d3bfa0ee337d1f606e5e994406969685a2986ab4 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" #include "tensorflow/python/lib/core/py_util.h" #include "tensorflow/python/lib/core/safe_ptr.h" + #include namespace tensorflow { @@ -141,7 +142,8 @@ bool IsSingleNone(PyObject* obj) { return false; } std::array indices; - char* item_ptr = static_cast(PyArray_GetPtr(array_obj, indices.data())); + char* item_ptr = + static_cast(PyArray_GetPtr(array_obj, indices.data())); PyObject* item = PyArray_GETITEM(array_obj, item_ptr); CHECK(item); return item == Py_None; @@ -301,13 +303,22 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) { if (PyBytes_AsStringAndSize(input_data[i], &el, &el_size) == -1) { #if PY_MAJOR_VERSION >= 3 el = PyUnicode_AsUTF8AndSize(input_data[i], &el_size); - if (!el) { +#else + el = nullptr; + if (PyUnicode_Check(input_data[i])) { + PyObject* unicode = PyUnicode_AsUTF8String(input_data[i]); + if (unicode) { + if (PyString_AsStringAndSize(unicode, &el, &el_size) == -1) { + Py_DECREF(unicode); + el = nullptr; + } + } + } #endif + if (!el) { return errors::Unimplemented("Unsupported object type ", input_data[i]->ob_type->tp_name); -#if PY_MAJOR_VERSION >= 3 } -#endif } tflat(i) = string(el, el_size); } diff --git a/tensorflow/python/lib/core/py_seq_tensor.h b/tensorflow/python/lib/core/py_seq_tensor.h index 6dc4d9c77755bd416fe709ad7a4bf350799f3eb1..c6e5080c62e96e79ca1ccf7e09e1b744ed293e07 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.h +++ b/tensorflow/python/lib/core/py_seq_tensor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_ +#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_ #include @@ -34,4 +34,4 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_ +#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_ diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h index 80db840aebcc7ca341b0f6c40fdaee2136d21aaa..32d286888666bde8742403bb8e231b3d6d4bf695 100644 --- a/tensorflow/python/lib/core/safe_ptr.h +++ b/tensorflow/python/lib/core/safe_ptr.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ +#ifndef TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ #include @@ -66,4 +66,4 @@ Safe_TF_StatusPtr make_safe(TF_Status* status); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ +#endif // TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 78b4a7101cd25844419d25f78ee97edddae03c3b..24a0c186198c7389af9add64ec6466b1f3d2afbd 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -103,16 +103,19 @@ from tensorflow.python.ops import gen_math_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_array_ops import * from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import # Used for slicing to specify a new 1 size dimension newaxis = None +tf_export("newaxis").export_constant(__name__, "newaxis") # We override the 'slice' for the "slice" op, so we keep python's # existing 'slice' for later use in this module. _BaseSlice = slice +@tf_export("identity") def identity(input, name=None): # pylint: disable=redefined-builtin r"""Return a tensor with the same shape and contents as input. @@ -135,6 +138,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin # pylint: disable=redefined-builtin,protected-access +@tf_export("expand_dims") def expand_dims(input, axis=None, name=None, dim=None): """Inserts a dimension of 1 into a tensor's shape. @@ -211,6 +215,7 @@ listdiff.__doc__ = gen_array_ops._list_diff.__doc__ + "\n" + listdiff.__doc__ # pylint: disable=undefined-variable,protected-access +@tf_export("setdiff1d") def setdiff1d(x, y, index_dtype=dtypes.int32, name=None): return gen_array_ops._list_diff(x, y, index_dtype, name) @@ -220,6 +225,7 @@ setdiff1d.__doc__ = gen_array_ops._list_diff.__doc__ # pylint: enable=protected-access +@tf_export("broadcast_dynamic_shape") def broadcast_dynamic_shape(shape_x, shape_y): # pylint: disable=protected-access """Returns the broadcasted dynamic shape between `shape_x` and `shape_y`. @@ -235,6 +241,7 @@ def broadcast_dynamic_shape(shape_x, shape_y): # pylint: enable=protected-access +@tf_export("broadcast_static_shape") def broadcast_static_shape(shape_x, shape_y): """Returns the broadcasted static shape between `shape_x` and `shape_y`. @@ -251,6 +258,7 @@ def broadcast_static_shape(shape_x, shape_y): return common_shapes.broadcast_shape(shape_x, shape_y) +@tf_export("shape") def shape(input, name=None, out_type=dtypes.int32): # pylint: disable=redefined-builtin """Returns the shape of a tensor. @@ -304,6 +312,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32): return gen_array_ops.shape(input, name=name, out_type=out_type) +@tf_export("shape_n") def shape_n(input, out_type=dtypes.int32, name=None): # pylint: disable=redefined-builtin """Returns shape of tensors. @@ -330,6 +339,7 @@ def shape_n(input, out_type=dtypes.int32, name=None): return output +@tf_export("size") def size(input, name=None, out_type=dtypes.int32): # pylint: disable=redefined-builtin """Returns the size of a tensor. @@ -387,6 +397,7 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32): return gen_array_ops.size(input, name=name, out_type=out_type) +@tf_export("rank") def rank(input, name=None): # pylint: disable=redefined-builtin """Returns the rank of a tensor. @@ -577,6 +588,7 @@ def _slice_helper(tensor, slice_spec, var=None): # pylint: disable=undefined-variable,protected-access,redefined-outer-name +@tf_export("slice") def slice(input_, begin, size, name=None): # pylint: disable=redefined-builtin """Extracts a slice from a tensor. @@ -629,6 +641,7 @@ def slice(input_, begin, size, name=None): # pylint: disable=invalid-name +@tf_export("strided_slice") def strided_slice(input_, begin, end, @@ -817,6 +830,7 @@ def _SliceHelperVar(var, slice_spec): ops.Tensor._override_operator("__getitem__", _slice_helper) +@tf_export("parallel_stack") def parallel_stack(values, name="parallel_stack"): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel. @@ -867,6 +881,7 @@ def parallel_stack(values, name="parallel_stack"): [expand_dims(value, 0) for value in values], shape=output_shape) +@tf_export("stack") def stack(values, axis=0, name="stack"): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. @@ -1012,6 +1027,7 @@ ops.register_tensor_conversion_function((list, tuple), _autopacking_conversion_function, 99) +@tf_export("unstack") def unstack(value, num=None, axis=0, name="unstack"): """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors. @@ -1061,6 +1077,7 @@ def unstack(value, num=None, axis=0, name="unstack"): return gen_array_ops._unpack(value, num=num, axis=axis, name=name) +@tf_export("concat") def concat(values, axis, name="concat"): """Concatenates tensors along one dimension. @@ -1157,6 +1174,7 @@ def concat(values, axis, name="concat"): return gen_array_ops._concat_v2(values=values, axis=axis, name=name) +@tf_export("boolean_mask") def boolean_mask(tensor, mask, name="boolean_mask", axis=None): """Apply boolean mask to tensor. Numpy equivalent is `tensor[mask]`. @@ -1237,6 +1255,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None): return _apply_mask_1d(tensor, mask, axis) +@tf_export("sparse_mask") def sparse_mask(a, mask_indices, name=None): """Masks elements of `IndexedSlices`. @@ -1279,6 +1298,7 @@ def sparse_mask(a, mask_indices, name=None): return ops.IndexedSlices(out_values, out_indices, a.dense_shape) +@tf_export("unique") def unique(x, out_idx=dtypes.int32, name=None): # TODO(yongtang): switch to v2 once API deprecation # period (3 weeks) pass. @@ -1290,6 +1310,7 @@ def unique(x, out_idx=dtypes.int32, name=None): unique.__doc__ = gen_array_ops._unique.__doc__ +@tf_export("split") def split(value, num_or_size_splits, axis=0, num=None, name="split"): """Splits a tensor into sub tensors. @@ -1356,6 +1377,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"): name=name) +@tf_export("transpose") def transpose(a, perm=None, name="transpose", conjugate=False): """Transposes `a`. Permutes the dimensions according to `perm`. @@ -1432,6 +1454,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False): # pylint: disable=invalid-name +@tf_export("matrix_transpose", "linalg.transpose") def matrix_transpose(a, name="matrix_transpose", conjugate=False): """Transposes last two dimensions of tensor `a`. @@ -1503,6 +1526,7 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False): # pylint: enable=invalid-name +@tf_export("zeros") def zeros(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to zero. @@ -1547,6 +1571,7 @@ def zeros(shape, dtype=dtypes.float32, name=None): return output +@tf_export("zeros_like") def zeros_like(tensor, dtype=None, name=None, optimize=True): """Creates a tensor with all elements set to zero. @@ -1563,9 +1588,9 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. - dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, - `complex64`, `complex128` or `bool`. + dtype: A type for the returned `Tensor`. Must be `float16`, `float32`, + `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, + `complex64`, `complex128`, `bool` or `string`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' and encode it as a constant. @@ -1599,6 +1624,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): return gen_array_ops._zeros_like(tensor, name=name) +@tf_export("ones_like") def ones_like(tensor, dtype=None, name=None, optimize=True): """Creates a tensor with all elements set to 1. @@ -1636,6 +1662,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): return ret +@tf_export("ones") def ones(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to 1. @@ -1675,6 +1702,7 @@ def ones(shape, dtype=dtypes.float32, name=None): return output +@tf_export("placeholder") def placeholder(dtype, shape=None, name=None): """Inserts a placeholder for a tensor that will be always fed. @@ -1728,6 +1756,7 @@ def _normalize_sparse_shape(shape, name): return (ops.convert_to_tensor(shape, dtype=dtypes.int64, name=name), rank) +@tf_export("sparse_placeholder") def sparse_placeholder(dtype, shape=None, name=None): """Inserts a placeholder for a sparse tensor that will be always fed. @@ -1794,6 +1823,7 @@ def sparse_placeholder(dtype, shape=None, name=None): # pylint: enable=redefined-outer-name +@tf_export("pad") def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name """Pads a tensor. @@ -1887,6 +1917,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl return result +@tf_export("meshgrid") def meshgrid(*args, **kwargs): """Broadcasts parameters for evaluation on an N-D grid. @@ -2026,6 +2057,7 @@ def _TileGradShape(op): return [tensor_shape.TensorShape(output_dims)] +@tf_export("edit_distance") def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"): """Computes the Levenshtein distance between sequences. @@ -2139,6 +2171,7 @@ def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad): narrow_range=op.get_attr("narrow_range")) +@tf_export("required_space_to_batch_paddings") def required_space_to_batch_paddings(input_shape, block_shape, base_paddings=None, @@ -2217,6 +2250,7 @@ def required_space_to_batch_paddings(input_shape, return result_paddings, result_crops +@tf_export("space_to_batch") def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=redefined-builtin result = space_to_batch_nd( input, @@ -2230,6 +2264,7 @@ def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=r space_to_batch.__doc__ = gen_array_ops._space_to_batch.__doc__ +@tf_export("space_to_depth") def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin return gen_array_ops.space_to_depth(input, block_size, data_format, name=name) @@ -2237,6 +2272,7 @@ def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__ +@tf_export("depth_to_space") def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin return gen_array_ops.depth_to_space(input, block_size, data_format, name=name) @@ -2244,6 +2280,7 @@ def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__ +@tf_export("batch_to_space") def batch_to_space(input, crops, block_size, name=None): # pylint: disable=redefined-builtin result = batch_to_space_nd( input, @@ -2257,6 +2294,7 @@ def batch_to_space(input, crops, block_size, name=None): # pylint: disable=rede batch_to_space.__doc__ = gen_array_ops._batch_to_space.__doc__ +@tf_export("one_hot") def one_hot(indices, depth, on_value=None, @@ -2416,6 +2454,7 @@ def _all_dimensions(x): return range(0, rank(x)) +@tf_export("sequence_mask") def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None): """Returns a mask tensor representing the first N positions of each cell. @@ -2478,6 +2517,7 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None): return gen_math_ops.cast(result, dtype) +@tf_export("squeeze") def squeeze(input, axis=None, name=None, squeeze_dims=None): # pylint: disable=redefined-builtin """Removes dimensions of size 1 from the shape of a tensor. @@ -2527,6 +2567,7 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None): return gen_array_ops._squeeze(input, axis, name) +@tf_export("where") def where(condition, x=None, y=None, name=None): """Return the elements, either from `x` or `y`, depending on the `condition`. @@ -2579,6 +2620,7 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") +@tf_export("reverse") def reverse(tensor, axis, name=None): return gen_array_ops.reverse_v2(tensor, axis, name) @@ -2587,6 +2629,7 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ # pylint: disable=redefined-builtin +@tf_export("reverse_sequence") def reverse_sequence(input, seq_lengths, seq_axis=None, @@ -2614,6 +2657,7 @@ reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring( "seq_dim", "seq_axis") +@tf_export("gather") def gather(params, indices, validate_indices=None, name=None, axis=0): # TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward # compatibility 3 week period has passed. @@ -2629,6 +2673,7 @@ gather.__doc__ = gen_array_ops.gather_v2.__doc__ # Define quantize_v2 here in order to make name the second-to-last attribute, # because round_mode was added later. +@tf_export("quantize_v2") @deprecation.deprecated( "2017-10-25", "`tf.quantize_v2` is deprecated, please use `tf.quantize` instead.") @@ -2653,6 +2698,7 @@ quantize_v2.__doc__ = """Please use `tf.quantize` instead.""" # We want to expose tf.quantize instead of tf.quantize_v2; we can deprecate # tf.quantize_v2 in next version of TensorFlow. +@tf_export("quantize") def quantize(input, # pylint: disable=redefined-builtin min_range, max_range, diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py index d6294c24f5cf9427209c9f5e84d05e32686908bf..20445c78a290a4fe67cad668dd714dd2c61c5f3d 100644 --- a/tensorflow/python/ops/candidate_sampling_ops.py +++ b/tensorflow/python/ops/candidate_sampling_ops.py @@ -23,8 +23,10 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_candidate_sampling_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export('nn.uniform_candidate_sampler') def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None): """Samples a set of classes using a uniform base distribution. @@ -80,6 +82,7 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, seed2=seed2, name=name) +@tf_export('nn.log_uniform_candidate_sampler') def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None): """Samples a set of classes using a log-uniform (Zipfian) base distribution. @@ -138,6 +141,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, seed2=seed2, name=name) +@tf_export('nn.learned_unigram_candidate_sampler') def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None): """Samples a set of classes from a distribution learned during training. @@ -194,6 +198,7 @@ def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, seed2=seed2, name=name) +@tf_export('nn.fixed_unigram_candidate_sampler') def fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, @@ -285,6 +290,7 @@ def fixed_unigram_candidate_sampler(true_classes, unigrams=unigrams, seed=seed1, seed2=seed2, name=name) +@tf_export('nn.all_candidate_sampler') def all_candidate_sampler(true_classes, num_true, num_sampled, unique, seed=None, name=None): """Generate the set of all classes. @@ -320,6 +326,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique, name=name) +@tf_export('nn.compute_accidental_hits') def compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None): """Compute the position ids in `sampled_candidates` matching `true_classes`. diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index eb7806ed0b4dc3022671d6b4248dc5924988534b..0fd6e29a49c8e4e31e244bfbbfca525d72e4d811 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -57,6 +57,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export NUMERIC_TYPES = frozenset( [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32, @@ -111,6 +112,7 @@ def _shape_and_dtype_str(tensor): return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name) +@tf_export('assert_proper_iterable') def assert_proper_iterable(values): """Static assert that values is a "proper" iterable. @@ -138,6 +140,7 @@ def assert_proper_iterable(values): 'Expected argument "values" to be iterable. Found: %s' % type(values)) +@tf_export('assert_negative') def assert_negative(x, data=None, summarize=None, message=None, name=None): """Assert the condition `x < 0` holds element-wise. @@ -178,6 +181,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None): return assert_less(x, zero, data=data, summarize=summarize) +@tf_export('assert_positive') def assert_positive(x, data=None, summarize=None, message=None, name=None): """Assert the condition `x > 0` holds element-wise. @@ -217,6 +221,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None): return assert_less(zero, x, data=data, summarize=summarize) +@tf_export('assert_non_negative') def assert_non_negative(x, data=None, summarize=None, message=None, name=None): """Assert the condition `x >= 0` holds element-wise. @@ -258,6 +263,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None): return assert_less_equal(zero, x, data=data, summarize=summarize) +@tf_export('assert_non_positive') def assert_non_positive(x, data=None, summarize=None, message=None, name=None): """Assert the condition `x <= 0` holds element-wise. @@ -299,6 +305,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None): return assert_less_equal(x, zero, data=data, summarize=summarize) +@tf_export('assert_equal') def assert_equal(x, y, data=None, summarize=None, message=None, name=None): """Assert the condition `x == y` holds element-wise. @@ -395,6 +402,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_none_equal') def assert_none_equal( x, y, data=None, summarize=None, message=None, name=None): """Assert the condition `x != y` holds for all elements. @@ -445,6 +453,7 @@ def assert_none_equal( return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_near') def assert_near( x, y, rtol=None, atol=None, data=None, summarize=None, message=None, name=None): @@ -522,6 +531,7 @@ def assert_near( return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_less') def assert_less(x, y, data=None, summarize=None, message=None, name=None): """Assert the condition `x < y` holds element-wise. @@ -569,6 +579,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None): return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_less_equal') def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): """Assert the condition `x <= y` holds element-wise. @@ -616,6 +627,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_greater') def assert_greater(x, y, data=None, summarize=None, message=None, name=None): """Assert the condition `x > y` holds element-wise. @@ -663,6 +675,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None): return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_greater_equal') def assert_greater_equal(x, y, data=None, summarize=None, message=None, name=None): """Assert the condition `x >= y` holds element-wise. @@ -760,6 +773,7 @@ def _assert_rank_condition( return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_rank') def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): """Assert `x` has rank equal to `rank`. @@ -821,6 +835,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): return assert_op +@tf_export('assert_rank_at_least') def assert_rank_at_least( x, rank, data=None, summarize=None, message=None, name=None): """Assert `x` has rank equal to `rank` or higher. @@ -951,6 +966,7 @@ def _assert_ranks_condition( return control_flow_ops.Assert(condition, data, summarize=summarize) +@tf_export('assert_rank_in') def assert_rank_in( x, ranks, data=None, summarize=None, message=None, name=None): """Assert `x` has rank in `ranks`. @@ -1012,6 +1028,7 @@ def assert_rank_in( return assert_op +@tf_export('assert_integer') def assert_integer(x, message=None, name=None): """Assert that `x` is of integer dtype. @@ -1049,6 +1066,7 @@ def assert_integer(x, message=None, name=None): return control_flow_ops.no_op('statically_determined_was_integer') +@tf_export('assert_type') def assert_type(tensor, tf_type, message=None, name=None): """Statically asserts that the given `Tensor` is of the specified type. @@ -1096,10 +1114,12 @@ def _get_diff_for_monotonic_comparison(x): return control_flow_ops.cond(is_shorter_than_two, short_result, diff) +@tf_export('is_numeric_tensor') def is_numeric_tensor(tensor): return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES +@tf_export('is_non_decreasing') def is_non_decreasing(x, name=None): """Returns `True` if `x` is non-decreasing. @@ -1126,6 +1146,7 @@ def is_non_decreasing(x, name=None): return math_ops.reduce_all(math_ops.less_equal(zero, diff)) +@tf_export('is_strictly_increasing') def is_strictly_increasing(x, name=None): """Returns `True` if `x` is strictly increasing. @@ -1184,6 +1205,7 @@ def _assert_same_base_type(items, expected_type=None): return expected_type +@tf_export('assert_same_float_dtype') def assert_same_float_dtype(tensors=None, dtype=None): """Validate and return float type based on `tensors` and `dtype`. @@ -1212,6 +1234,7 @@ def assert_same_float_dtype(tensors=None, dtype=None): return dtype +@tf_export('assert_scalar') def assert_scalar(tensor, name=None): with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: tensor = ops.convert_to_tensor(tensor, name=name_scope) diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 80803530c1ede4537e729ef77958a5d905005dd3..dd8c33247c2436413ee8c9a3ceeca4d8a493bb4e 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -28,8 +28,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("clip_by_value") def clip_by_value(t, clip_value_min, clip_value_max, name=None): """Clips tensor values to a specified min and max. @@ -70,6 +72,7 @@ def clip_by_value(t, clip_value_min, clip_value_max, return t_max +@tf_export("clip_by_norm") def clip_by_norm(t, clip_norm, axes=None, name=None): """Clips tensor values to a maximum L2-norm. @@ -117,6 +120,8 @@ def clip_by_norm(t, clip_norm, axes=None, name=None): return tclip + +@tf_export("global_norm") def global_norm(t_list, name=None): """Computes the global norm of multiple tensors. @@ -164,6 +169,8 @@ def global_norm(t_list, name=None): return norm + +@tf_export("clip_by_global_norm") def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): """Clips values of multiple tensors by the ratio of the sum of their norms. @@ -246,6 +253,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): return list_clipped, use_norm +@tf_export("clip_by_average_norm") def clip_by_average_norm(t, clip_norm, name=None): """Clips tensor values to a maximum average L2-norm. diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index 32e071db1749ceed56e2f31446e58213d0603705..50690cd891f73df1e345817b834ce6c361bff9e8 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.util.tf_export import tf_export def remove_squeezable_dimensions( @@ -93,6 +94,7 @@ def remove_squeezable_dimensions( return labels, predictions +@tf_export('confusion_matrix') def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, name=None, weights=None): """Computes the confusion matrix from predictions and labels. diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 86941a7f2ae7d6ba8622d5c4ceafdb9a689eaca0..d379eccc20dcd63255ee8c2dbe3fbd3e6a9077af 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -82,6 +82,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use +from tensorflow.python.util.tf_export import tf_export # We override the 'tuple' for a control flow op, so we keep python's @@ -117,6 +118,7 @@ def _summarize_eager(tensor, summarize=None): # Assert and Print are special symbols in python, so we must # use an upper-case version of them. +@tf_export("Assert") @tf_should_use.should_use_result def Assert(condition, data, summarize=None, name=None): """Asserts that the given condition is true. @@ -1867,6 +1869,7 @@ def _UnpackIfSingleton(res): # pylint: disable=redefined-outer-name # pylint: disable=g-doc-args +@tf_export("cond") @deprecation.deprecated_args( None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", @@ -2843,6 +2846,7 @@ class WhileContext(ControlFlowContext): # pylint: disable=redefined-outer-name +@tf_export("while_loop") def while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None, maximum_iterations=None): @@ -3110,6 +3114,7 @@ def _GroupControlDeps(dev, deps, name=None): # TODO(touts): Accept "inputs" as a list. +@tf_export("group") def group(*inputs, **kwargs): """Create an op that groups multiple operations. @@ -3175,6 +3180,7 @@ def group(*inputs, **kwargs): return no_op(name=name) +@tf_export("tuple") def tuple(tensors, name=None, control_inputs=None): """Group tensors together. @@ -3328,6 +3334,7 @@ def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name): return predicates, actions +@tf_export("case") def case(pred_fn_pairs, default=None, exclusive=False, diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index f037767cf4051d058a2da0cca9c4515fd9705d28..83da6739db673644f59fda3044769b18b2138fbc 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -25,9 +25,11 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_ctc_ops from tensorflow.python.ops.nn_grad import _BroadcastMul +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access, invalid-name +@tf_export("nn.ctc_loss") def ctc_loss(labels, inputs, sequence_length, preprocess_collapse_repeated=False, ctc_merge_repeated=True, @@ -185,6 +187,7 @@ def _CTCLossGrad(op, grad_loss, _): return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None] +@tf_export("nn.ctc_greedy_decoder") def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): """Performs greedy decoding on the logits given in input (best path). @@ -228,6 +231,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): log_probabilities) +@tf_export("nn.ctc_beam_search_decoder") def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, top_paths=1, merge_repeated=True): """Performs beam search decoding on the logits given in input. diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index f441f6d4bf7986bbfb15593edf2b2b1bfe6ec71f..34f0bf7b78a75533cb89ed549afad90f3c066b94 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import math_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * +from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import @@ -107,6 +108,7 @@ def _shape_common(s1, s2): # pylint: disable=protected-access +@tf_export("QueueBase") class QueueBase(object): """Base class for queue implementations. @@ -596,6 +598,7 @@ class QueueBase(object): return gen_data_flow_ops._queue_size(self._queue_ref, name=name) +@tf_export("RandomShuffleQueue") class RandomShuffleQueue(QueueBase): """A queue implementation that dequeues elements in a random order. @@ -674,6 +677,7 @@ class RandomShuffleQueue(QueueBase): super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) +@tf_export("FIFOQueue") class FIFOQueue(QueueBase): """A queue implementation that dequeues elements in first-in first-out order. @@ -727,6 +731,7 @@ class FIFOQueue(QueueBase): super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) +@tf_export("PaddingFIFOQueue") class PaddingFIFOQueue(QueueBase): """A FIFOQueue that supports batching variable-sized tensors by padding. @@ -797,6 +802,7 @@ class PaddingFIFOQueue(QueueBase): super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) +@tf_export("PriorityQueue") class PriorityQueue(QueueBase): """A queue implementation that dequeues elements in prioritized order. @@ -1106,6 +1112,7 @@ class Barrier(object): self._barrier_ref, name=name) +@tf_export("ConditionalAccumulatorBase") class ConditionalAccumulatorBase(object): """A conditional accumulator for aggregating gradients. @@ -1184,6 +1191,7 @@ class ConditionalAccumulatorBase(object): name=name) +@tf_export("ConditionalAccumulator") class ConditionalAccumulator(ConditionalAccumulatorBase): """A conditional accumulator for aggregating gradients. @@ -1263,6 +1271,7 @@ class ConditionalAccumulator(ConditionalAccumulatorBase): return out +@tf_export("SparseConditionalAccumulator") class SparseConditionalAccumulator(ConditionalAccumulatorBase): """A conditional accumulator for aggregating sparse gradients. diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py index b6b20d1b4a893a4c109560be717339d75fc7ccfc..1f300b7147be505a316c38ae57cadeae2bd7ea10 100644 --- a/tensorflow/python/ops/distributions/bernoulli.py +++ b/tensorflow/python/ops/distributions/bernoulli.py @@ -29,8 +29,10 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export +@tf_export("distributions.Bernoulli") class Bernoulli(distribution.Distribution): """Bernoulli distribution. diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py index 2b93478cdf9f9e80f4c2c19ad25cb270a8e7aa98..6d6b40b04557a4483f60d8c06c35f937d38a24b9 100644 --- a/tensorflow/python/ops/distributions/beta.py +++ b/tensorflow/python/ops/distributions/beta.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -45,6 +46,7 @@ _beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in `[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" +@tf_export("distributions.Beta") class Beta(distribution.Distribution): """Beta distribution. diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py index 8f6d18d91ae19ada5ff3715b523635ec8c88adc3..44d64070ce48c0c115ea7edb1237124bc6698e90 100644 --- a/tensorflow/python/ops/distributions/bijector_impl.py +++ b/tensorflow/python/ops/distributions/bijector_impl.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -111,6 +112,7 @@ class _Mapping(collections.namedtuple( @six.add_metaclass(abc.ABCMeta) +@tf_export("distributions.bijectors.Bijector") class Bijector(object): """Interface for transformations of a `Distribution` sample. diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py index 2046a08d618faf592fb3fc8230d8f3c4c5e8c7c7..9161e3fa9f5f7f844e7f4926992c954acae246d6 100644 --- a/tensorflow/python/ops/distributions/categorical.py +++ b/tensorflow/python/ops/distributions/categorical.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): @@ -58,6 +59,7 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): return event, params +@tf_export("distributions.Categorical") class Categorical(distribution.Distribution): """Categorical distribution. @@ -263,12 +265,13 @@ class Categorical(distribution.Distribution): logits_2d = self.logits else: logits_2d = array_ops.reshape(self.logits, [-1, self.event_size]) + sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32 draws = random_ops.multinomial( - logits_2d, n, seed=seed, output_dtype=self.dtype) + logits_2d, n, seed=seed, output_dtype=sample_dtype) draws = array_ops.reshape( array_ops.transpose(draws), array_ops.concat([[n], self.batch_shape_tensor()], 0)) - return draws + return math_ops.cast(draws, self.dtype) def _cdf(self, k): k = ops.convert_to_tensor(k, name="k") diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py index 2accedf1b963f01034f0b4059f44e46eb9bfc5ab..25afeec936069b9cbf926cdc3bbb79226a79aa30 100644 --- a/tensorflow/python/ops/distributions/dirichlet.py +++ b/tensorflow/python/ops/distributions/dirichlet.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -42,6 +43,7 @@ dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e., `self.batch_shape() + self.event_shape()`.""" +@tf_export("distributions.Dirichlet") class Dirichlet(distribution.Distribution): """Dirichlet distribution. diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py index aa2b511c5413944df665198eacc26066b8457773..03a98c56ba509ea1f70f12a74ba67b903013cf70 100644 --- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py +++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -49,6 +50,7 @@ fractional components, and such that with `self.concentration` and `self.total_count`.""" +@tf_export("distributions.DirichletMultinomial") class DirichletMultinomial(distribution.Distribution): """Dirichlet-Multinomial compound distribution. diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index 098622c52f431fed1f6e21ffaaed9ebc7142f227..4071e50e815b01d30f3e24ba4677cc37b325f24d 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -197,6 +198,7 @@ class _DistributionMeta(abc.ABCMeta): return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) +@tf_export("distributions.ReparameterizationType") class ReparameterizationType(object): """Instances of this class represent how sampling is reparameterized. @@ -239,15 +241,20 @@ class ReparameterizationType(object): # reparameterized distribution support straight-through gradients with # respect to all parameters. FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED") +tf_export("distributions.FULLY_REPARAMETERIZED").export_constant( + __name__, "FULLY_REPARAMETERIZED") # Not reparameterized distribution: samples from a non- # reparameterized distribution do not support straight-through gradients for # at least some of the parameters. NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED") +tf_export("distributions.NOT_REPARAMETERIZED").export_constant( + __name__, "NOT_REPARAMETERIZED") @six.add_metaclass(_DistributionMeta) +@tf_export("distributions.Distribution") class Distribution(_BaseDistribution): """A generic probability distribution base class. diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py index 281641b9156b9631199efc78ea1c2d30119dadb8..6345a76d485c64659aa01fa1611cd27426d8c8a5 100644 --- a/tensorflow/python/ops/distributions/exponential.py +++ b/tensorflow/python/ops/distributions/exponential.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import gamma +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -35,6 +36,7 @@ __all__ = [ ] +@tf_export("distributions.Exponential") class Exponential(gamma.Gamma): """Exponential distribution. diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py index 4ac2b9b4ef894fd9a603ff67bf9c8754f1e23b8e..8fb218be3ac7e17e18d85b8e1c100ccd58aa1034 100644 --- a/tensorflow/python/ops/distributions/gamma.py +++ b/tensorflow/python/ops/distributions/gamma.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -41,6 +42,7 @@ __all__ = [ ] +@tf_export("distributions.Gamma") class Gamma(distribution.Distribution): """Gamma distribution. diff --git a/tensorflow/python/ops/distributions/identity_bijector.py b/tensorflow/python/ops/distributions/identity_bijector.py index f277eda8bbfb88f2344dfd620c573e0acd8d8078..2972c3554b3639a1ae30a4167f73613b1ff8add2 100644 --- a/tensorflow/python/ops/distributions/identity_bijector.py +++ b/tensorflow/python/ops/distributions/identity_bijector.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -27,6 +28,7 @@ __all__ = [ ] +@tf_export("distributions.bijectors.Identity") class Identity(bijector.Bijector): """Compute Y = g(X) = X. diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py index 829b9611cff02895b67ec39711b8c53e682eb3c5..e3c6f3e789eaf57d1fc5a1fcf244c3a0ef2fe0b8 100644 --- a/tensorflow/python/ops/distributions/kullback_leibler.py +++ b/tensorflow/python/ops/distributions/kullback_leibler.py @@ -23,6 +23,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export _DIVERGENCES = {} @@ -50,6 +51,7 @@ def _registered_kl(type_a, type_b): return kl_fn +@tf_export("distributions.kl_divergence") def kl_divergence(distribution_a, distribution_b, allow_nan_stats=True, name=None): """Get the KL-divergence KL(distribution_a || distribution_b). @@ -142,6 +144,7 @@ def cross_entropy(ref, other, ref, other, allow_nan_stats=allow_nan_stats) +@tf_export("distributions.RegisterKL") class RegisterKL(object): """Decorator to register a KL divergence implementation function. diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py index 5c964ff78a53b6d2dec588b85abff2c5b1173c06..e98ac855c58efa1ef3ccef2de24f329d839bac26 100644 --- a/tensorflow/python/ops/distributions/laplace.py +++ b/tensorflow/python/ops/distributions/laplace.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -41,6 +42,7 @@ __all__ = [ ] +@tf_export("distributions.Laplace") class Laplace(distribution.Distribution): """The Laplace distribution with location `loc` and `scale` parameters. diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py index 04762565c2a982f4df47a1a85547db7a104a5ec3..26b5c5aef98fc11b07a8c8357e7ec37819587da9 100644 --- a/tensorflow/python/ops/distributions/multinomial.py +++ b/tensorflow/python/ops/distributions/multinomial.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -50,6 +51,7 @@ fractional components, and such that with `self.probs` and `self.total_count`.""" +@tf_export("distributions.Multinomial") class Multinomial(distribution.Distribution): """Multinomial distribution. diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py index 0ef1c91df8c83146fdae086d6056b1d947bae128..e7f120ea2da525e20a1ae42e6418cf2ac83686af 100644 --- a/tensorflow/python/ops/distributions/normal.py +++ b/tensorflow/python/ops/distributions/normal.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import special_math +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -40,6 +41,7 @@ __all__ = [ ] +@tf_export("distributions.Normal") class Normal(distribution.Distribution): """The Normal distribution with location `loc` and `scale` parameters. diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py index 073ac4286be170dcfd564f61f1026a85d95c772c..778fefb8c2991153b7e7a1f20df61680153dab2a 100644 --- a/tensorflow/python/ops/distributions/student_t.py +++ b/tensorflow/python/ops/distributions/student_t.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -41,6 +42,7 @@ __all__ = [ ] +@tf_export("distributions.StudentT") class StudentT(distribution.Distribution): """Student's t-distribution. diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py index 9b555f87eae14fe30ff020f996778a4ad8f98ab9..3580af18f241d777c81340f1c565074914838029 100644 --- a/tensorflow/python/ops/distributions/uniform.py +++ b/tensorflow/python/ops/distributions/uniform.py @@ -29,8 +29,10 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util.tf_export import tf_export +@tf_export("distributions.Uniform") class Uniform(distribution.Distribution): """Uniform distribution with `low` and `high` parameters. diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index f4561d1a830141a069c12ddb33b83744363844f2..3826585f59c31133b12c365816729e090c9ab561 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export def _gather(params, ids, name=None): @@ -257,6 +258,7 @@ def _embedding_lookup_and_transform(params, return ret +@tf_export("nn.embedding_lookup") def embedding_lookup( params, ids, @@ -325,6 +327,7 @@ def embedding_lookup( transform_fn=None) +@tf_export("nn.embedding_lookup_sparse") def embedding_lookup_sparse(params, sp_ids, sp_weights, diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 688512bea6b274eed2823794f017e14eb4f128f5..7dbccf1caf1486bb247a1bef0ac37c36adbcc53e 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -44,9 +44,11 @@ from tensorflow.python.ops.gen_functional_ops import * from tensorflow.python.ops.gen_functional_ops import _symbolic_gradient # pylint: enable=unused-import from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export # TODO(yuanbyu, mrry): Handle stride to support sliding windows. +@tf_export("foldl") def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None): """foldl on the list of tensors unpacked from `elems` on dimension 0. @@ -134,6 +136,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, return r_a +@tf_export("foldr") def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None): """foldr on the list of tensors unpacked from `elems` on dimension 0. @@ -221,6 +224,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, return r_a +@tf_export("map_fn") def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, swap_memory=False, infer_shape=True, name=None): """map on the list of tensors unpacked from `elems` on dimension 0. @@ -424,6 +428,7 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, return output_pack(results_flat) +@tf_export("scan") def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, infer_shape=True, name=None): """scan on the list of tensors unpacked from `elems` on dimension 0. diff --git a/tensorflow/python/ops/gradient_checker.py b/tensorflow/python/ops/gradient_checker.py index 193046ba70e3448db4e5baac54be3699983b34b8..12afcd0b517d5e85112c067ccaca5693e5a4e231 100644 --- a/tensorflow/python/ops/gradient_checker.py +++ b/tensorflow/python/ops/gradient_checker.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export def _product(t): @@ -264,6 +265,7 @@ def _compute_gradient_list(x, return ret +@tf_export("test.compute_gradient") def compute_gradient(x, x_shape, y, @@ -325,6 +327,7 @@ def compute_gradient(x, return ret +@tf_export("test.compute_gradient_error") def compute_gradient_error(x, x_shape, y, diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 20c7a9fd6629cfe4657d8c0a25e2c6c2aad8ed49..5d4b9ecd8bee31c5092b04535e97b036eec9f1be 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -50,6 +50,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export # Warn the user if we convert a sparse representation to dense with at @@ -394,6 +395,7 @@ def _MaybeCompile(scope, op, func, grad_fn): return grad_fn() +@tf_export("gradients") def gradients(ys, xs, grad_ys=None, @@ -799,6 +801,7 @@ def _MultiDeviceAddN(tensor_list): return math_ops.add_n(summands) +@tf_export("AggregationMethod") class AggregationMethod(object): """A class listing aggregation methods used to combine gradients. @@ -971,6 +974,7 @@ def _hessian_vector_product(ys, xs, v): return gradients(elemwise_products, xs) +@tf_export("hessians") def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None): """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`. diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py index 4313b79b5b3e6045a5102c6ac29a2c3291e1b0aa..f079e56b10ed484225d8f09c6eaf7cf85a02d12a 100644 --- a/tensorflow/python/ops/histogram_ops.py +++ b/tensorflow/python/ops/histogram_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export def histogram_fixed_width_bins(values, @@ -56,7 +57,7 @@ def histogram_fixed_width_bins(values, Returns: A `Tensor` holding the indices of the binned values whose shape matches - `values`. + `values`. Examples: @@ -73,7 +74,7 @@ def histogram_fixed_width_bins(values, ``` """ with ops.name_scope(name, 'histogram_fixed_width_bins', - [values, value_range, nbins]) as scope: + [values, value_range, nbins]): values = ops.convert_to_tensor(values, name='values') shape = array_ops.shape(values) @@ -83,9 +84,10 @@ def histogram_fixed_width_bins(values, nbins_float = math_ops.cast(nbins, values.dtype) # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv(values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') + scaled_values = math_ops.truediv( + values - value_range[0], + value_range[1] - value_range[0], + name='scaled_values') # map tensor values within the open interval value_range to {0,.., nbins-1}, # values outside the open interval will be zero or less, or nbins or more. @@ -97,6 +99,7 @@ def histogram_fixed_width_bins(values, return array_ops.reshape(indices, shape) +@tf_export('histogram_fixed_width') def histogram_fixed_width(values, value_range, nbins=100, @@ -136,5 +139,5 @@ def histogram_fixed_width(values, """ with ops.name_scope(name, 'histogram_fixed_width', [values, value_range, nbins]) as name: - return gen_math_ops._histogram_fixed_width(values, value_range, nbins, - dtype=dtype, name=name) + return gen_math_ops._histogram_fixed_width( # pylint: disable=protected-access + values, value_range, nbins, dtype=dtype, name=name) diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py index 80ee09057581db7298562fc22b443f5ddee73ef8..a226ac81bb536934cd191872ffc1aca84925abc0 100644 --- a/tensorflow/python/ops/histogram_ops_test.py +++ b/tensorflow/python/ops/histogram_ops_test.py @@ -36,7 +36,8 @@ class BinValuesFixedWidth(test.TestCase): values = [] expected_bins = [] with self.test_session(): - bins = histogram_ops.histogram_fixed_width_bins(values, value_range, nbins=5) + bins = histogram_ops.histogram_fixed_width_bins( + values, value_range, nbins=5) self.assertEqual(dtypes.int32, bins.dtype) self.assertAllClose(expected_bins, bins.eval()) @@ -69,8 +70,7 @@ class BinValuesFixedWidth(test.TestCase): # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) value_range = [0.0, 5.0] values = constant_op.constant( - [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]], - shape=(2, 3)) + [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]], shape=(2, 3)) expected_bins = [[0, 0, 1], [2, 4, 4]] with self.test_session(): bins = histogram_ops.histogram_fixed_width_bins( @@ -140,8 +140,8 @@ class HistogramFixedWidthTest(test.TestCase): self.assertEqual(dtypes.int32, hist.dtype) self.assertAllClose(expected_bin_counts, hist.eval()) - hist = histogram_ops.histogram_fixed_width(values, value_range, - nbins=placeholder) + hist = histogram_ops.histogram_fixed_width( + values, value_range, nbins=placeholder) self.assertEquals(hist.shape.ndims, 1) self.assertIs(hist.shape[0].value, None) self.assertEqual(dtypes.int32, hist.dtype) diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 3b0b5a978c9f79dca9b87d3a7b6478b63e1fcb8d..de12c5f63f4357e0982dd2e16999caf2de0b30f8 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -49,6 +49,10 @@ See the @{$python/image} guide. @@grayscale_to_rgb @@hsv_to_rgb @@rgb_to_hsv +@@rgb_to_yiq +@@yiq_to_rgb +@@rgb_to_yuv +@@yuv_to_rgb @@convert_image_dtype @@adjust_brightness @@random_brightness diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 9f09d0a4d1ff4eed9647b6c74db0b1803df0ad70..721efcf78656a8832763a473668c108454bde915 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Implementation of image ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +25,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_nn_ops @@ -36,7 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables - +from tensorflow.python.util.tf_export import tf_export ops.NotDifferentiable('RandomCrop') # TODO(b/31222613): This op may be differentiable, and there may be @@ -109,8 +105,9 @@ def _ImageDimensions(image, rank): else: static_shape = image.get_shape().with_rank(rank).as_list() dynamic_shape = array_ops.unstack(array_ops.shape(image), rank) - return [s if s is not None else d - for s, d in zip(static_shape, dynamic_shape)] + return [ + s if s is not None else d for s, d in zip(static_shape, dynamic_shape) + ] def _Check3DImage(image, require_static=True): @@ -131,22 +128,45 @@ def _Check3DImage(image, require_static=True): try: image_shape = image.get_shape().with_rank(3) except ValueError: - raise ValueError("'image' (shape %s) must be three-dimensional." % - image.shape) + raise ValueError( + "'image' (shape %s) must be three-dimensional." % image.shape) if require_static and not image_shape.is_fully_defined(): - raise ValueError("'image' (shape %s) must be fully defined." % - image_shape) + raise ValueError("'image' (shape %s) must be fully defined." % image_shape) if any(x == 0 for x in image_shape): - raise ValueError("all dims of 'image.shape' must be > 0: %s" % - image_shape) + raise ValueError("all dims of 'image.shape' must be > 0: %s" % image_shape) if not image_shape.is_fully_defined(): - return [check_ops.assert_positive(array_ops.shape(image), - ["all dims of 'image.shape' " - "must be > 0."])] + return [ + check_ops.assert_positive( + array_ops.shape(image), + ["all dims of 'image.shape' " + 'must be > 0.']) + ] else: return [] +def _Assert3DImage(image): + """Assert that we are working with a properly shaped image. + + Performs the check statically if possible (i.e. if the shape + is statically known). Otherwise adds a control dependency + to an assert op that checks the dynamic shape. + + Args: + image: 3-D Tensor of shape [height, width, channels] + + Raises: + ValueError: if `image.shape` is not a 3-vector. + + Returns: + If the shape of `image` could be verified statically, `image` is + returned unchanged, otherwise there will be a control dependency + added that asserts the correct dynamic shape. + """ + return control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) + + def _CheckAtLeast3DImage(image, require_static=True): """Assert that we are working with properly shaped image. @@ -172,12 +192,15 @@ def _CheckAtLeast3DImage(image, require_static=True): if require_static and not image_shape.is_fully_defined(): raise ValueError('\'image\' must be fully defined.') if any(x == 0 for x in image_shape): - raise ValueError('all dims of \'image.shape\' must be > 0: %s' % - image_shape) + raise ValueError( + 'all dims of \'image.shape\' must be > 0: %s' % image_shape) if not image_shape.is_fully_defined(): - return [check_ops.assert_positive(array_ops.shape(image), - ["all dims of 'image.shape' " - "must be > 0."])] + return [ + check_ops.assert_positive( + array_ops.shape(image), + ["all dims of 'image.shape' " + 'must be > 0.']) + ] else: return [] @@ -201,6 +224,7 @@ def fix_image_flip_shape(image, result): return result +@tf_export('image.random_flip_up_down') def random_flip_up_down(image, seed=None): """Randomly flips an image vertically (upside down). @@ -221,17 +245,18 @@ def random_flip_up_down(image, seed=None): """ with ops.name_scope(None, 'random_flip_up_down', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) mirror_cond = math_ops.less(uniform_random, .5) - result = control_flow_ops.cond(mirror_cond, - lambda: array_ops.reverse(image, [0]), - lambda: image, - name=scope) + result = control_flow_ops.cond( + mirror_cond, + lambda: array_ops.reverse(image, [0]), + lambda: image, + name=scope) return fix_image_flip_shape(image, result) +@tf_export('image.random_flip_left_right') def random_flip_left_right(image, seed=None): """Randomly flip an image horizontally (left to right). @@ -252,17 +277,18 @@ def random_flip_left_right(image, seed=None): """ with ops.name_scope(None, 'random_flip_left_right', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) mirror_cond = math_ops.less(uniform_random, .5) - result = control_flow_ops.cond(mirror_cond, - lambda: array_ops.reverse(image, [1]), - lambda: image, - name=scope) + result = control_flow_ops.cond( + mirror_cond, + lambda: array_ops.reverse(image, [1]), + lambda: image, + name=scope) return fix_image_flip_shape(image, result) +@tf_export('image.flip_left_right') def flip_left_right(image): """Flip an image horizontally (left to right). @@ -282,12 +308,12 @@ def flip_left_right(image): """ with ops.name_scope(None, 'flip_left_right', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) - return fix_image_flip_shape(image, - array_ops.reverse(image, [1], name=scope)) + image = _Assert3DImage(image) + return fix_image_flip_shape(image, array_ops.reverse( + image, [1], name=scope)) +@tf_export('image.flip_up_down') def flip_up_down(image): """Flip an image vertically (upside down). @@ -307,12 +333,12 @@ def flip_up_down(image): """ with ops.name_scope(None, 'flip_up_down', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) - return fix_image_flip_shape(image, - array_ops.reverse(image, [0], name=scope)) + image = _Assert3DImage(image) + return fix_image_flip_shape(image, array_ops.reverse( + image, [0], name=scope)) +@tf_export('image.rot90') def rot90(image, k=1, name=None): """Rotate an image counter-clockwise by 90 degrees. @@ -326,30 +352,30 @@ def rot90(image, k=1, name=None): """ with ops.name_scope(name, 'rot90', [image, k]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k') k.get_shape().assert_has_rank(0) k = math_ops.mod(k, 4) def _rot90(): - return array_ops.transpose(array_ops.reverse_v2(image, [1]), - [1, 0, 2]) + return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2]) + def _rot180(): return array_ops.reverse_v2(image, [0, 1]) + def _rot270(): - return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), - [1]) - cases = [(math_ops.equal(k, 1), _rot90), - (math_ops.equal(k, 2), _rot180), + return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1]) + + cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180), (math_ops.equal(k, 3), _rot270)] - ret = control_flow_ops.case(cases, default=lambda: image, exclusive=True, - name=scope) + ret = control_flow_ops.case( + cases, default=lambda: image, exclusive=True, name=scope) ret.set_shape([None, None, image.get_shape()[2]]) return ret +@tf_export('image.transpose_image') def transpose_image(image): """Transpose an image by swapping the first and second dimension. @@ -366,11 +392,11 @@ def transpose_image(image): """ with ops.name_scope(None, 'transpose_image', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) return array_ops.transpose(image, [1, 0, 2], name=scope) +@tf_export('image.central_crop') def central_crop(image, central_fraction): """Crop the central region of the image. @@ -402,8 +428,7 @@ def central_crop(image, central_fraction): if central_fraction == 1.0: return image - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) img_shape = array_ops.shape(image) depth = image.get_shape()[2] @@ -424,6 +449,7 @@ def central_crop(image, central_fraction): return image +@tf_export('image.pad_to_bounding_box') def pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width): """Pad `image` with zeros to the specified `height` and `width`. @@ -494,8 +520,10 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, ]), [4, 2]) padded = array_ops.pad(image, paddings) - padded_shape = [None if _is_tensor(i) else i - for i in [batch, target_height, target_width, depth]] + padded_shape = [ + None if _is_tensor(i) else i + for i in [batch, target_height, target_width, depth] + ] padded.set_shape(padded_shape) if not is_batch: @@ -504,6 +532,7 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, return padded +@tf_export('image.crop_to_bounding_box') def crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width): """Crops an image to a specified bounding box. @@ -568,12 +597,13 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, image = control_flow_ops.with_dependencies(assert_ops, image) cropped = array_ops.slice( - image, - array_ops.stack([0, offset_height, offset_width, 0]), + image, array_ops.stack([0, offset_height, offset_width, 0]), array_ops.stack([-1, target_height, target_width, -1])) - cropped_shape = [None if _is_tensor(i) else i - for i in [batch, target_height, target_width, depth]] + cropped_shape = [ + None if _is_tensor(i) else i + for i in [batch, target_height, target_width, depth] + ] cropped.set_shape(cropped_shape) if not is_batch: @@ -582,6 +612,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, return cropped +@tf_export('image.resize_image_with_crop_or_pad') def resize_image_with_crop_or_pad(image, target_height, target_width): """Crops and/or pads an image to a target width and height. @@ -637,8 +668,8 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): target_height = control_flow_ops.with_dependencies( assert_ops, target_height) if _is_tensor(target_width): - target_width = control_flow_ops.with_dependencies( - assert_ops, target_width) + target_width = control_flow_ops.with_dependencies(assert_ops, + target_width) def max_(x, y): if _is_tensor(x) or _is_tensor(y): @@ -683,10 +714,12 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): _, resized_height, resized_width, _ = _ImageDimensions(resized, rank=4) assert_ops = [] - assert_ops += _assert(equal_(resized_height, target_height), ValueError, - 'resized height is not correct.') - assert_ops += _assert(equal_(resized_width, target_width), ValueError, - 'resized width is not correct.') + assert_ops += _assert( + equal_(resized_height, target_height), ValueError, + 'resized height is not correct.') + assert_ops += _assert( + equal_(resized_width, target_width), ValueError, + 'resized width is not correct.') resized = control_flow_ops.with_dependencies(assert_ops, resized) @@ -696,6 +729,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): return resized +@tf_export('image.ResizeMethod') class ResizeMethod(object): BILINEAR = 0 NEAREST_NEIGHBOR = 1 @@ -703,6 +737,7 @@ class ResizeMethod(object): AREA = 3 +@tf_export('image.resize_images') def resize_images(images, size, method=ResizeMethod.BILINEAR, @@ -785,22 +820,17 @@ def resize_images(images, return images if method == ResizeMethod.BILINEAR: - images = gen_image_ops.resize_bilinear(images, - size, - align_corners=align_corners) + images = gen_image_ops.resize_bilinear( + images, size, align_corners=align_corners) elif method == ResizeMethod.NEAREST_NEIGHBOR: - images = gen_image_ops.resize_nearest_neighbor(images, - size, - align_corners= - align_corners) + images = gen_image_ops.resize_nearest_neighbor( + images, size, align_corners=align_corners) elif method == ResizeMethod.BICUBIC: - images = gen_image_ops.resize_bicubic(images, - size, - align_corners=align_corners) + images = gen_image_ops.resize_bicubic( + images, size, align_corners=align_corners) elif method == ResizeMethod.AREA: - images = gen_image_ops.resize_area(images, - size, - align_corners=align_corners) + images = gen_image_ops.resize_area( + images, size, align_corners=align_corners) else: raise ValueError('Resize method is not implemented.') @@ -813,6 +843,7 @@ def resize_images(images, return images +@tf_export('image.per_image_standardization') def per_image_standardization(image): """Linearly scales `image` to have zero mean and unit norm. @@ -834,15 +865,15 @@ def per_image_standardization(image): """ with ops.name_scope(None, 'per_image_standardization', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) num_pixels = math_ops.reduce_prod(array_ops.shape(image)) image = math_ops.cast(image, dtype=dtypes.float32) image_mean = math_ops.reduce_mean(image) - variance = (math_ops.reduce_mean(math_ops.square(image)) - - math_ops.square(image_mean)) + variance = ( + math_ops.reduce_mean(math_ops.square(image)) - + math_ops.square(image_mean)) variance = gen_nn_ops.relu(variance) stddev = math_ops.sqrt(variance) @@ -856,6 +887,7 @@ def per_image_standardization(image): return image +@tf_export('image.random_brightness') def random_brightness(image, max_delta, seed=None): """Adjust the brightness of images by a random factor. @@ -882,6 +914,7 @@ def random_brightness(image, max_delta, seed=None): return adjust_brightness(image, delta) +@tf_export('image.random_contrast') def random_contrast(image, lower, upper, seed=None): """Adjust the contrast of an image by a random factor. @@ -913,6 +946,7 @@ def random_contrast(image, lower, upper, seed=None): return adjust_contrast(image, contrast_factor) +@tf_export('image.adjust_brightness') def adjust_brightness(image, delta): """Adjust the brightness of RGB or Grayscale images. @@ -940,13 +974,13 @@ def adjust_brightness(image, delta): orig_dtype = image.dtype flt_image = convert_image_dtype(image, dtypes.float32) - adjusted = math_ops.add(flt_image, - math_ops.cast(delta, dtypes.float32), - name=name) + adjusted = math_ops.add( + flt_image, math_ops.cast(delta, dtypes.float32), name=name) return convert_image_dtype(adjusted, orig_dtype, saturate=True) +@tf_export('image.adjust_contrast') def adjust_contrast(images, contrast_factor): """Adjust contrast of RGB or grayscale images. @@ -980,14 +1014,14 @@ def adjust_contrast(images, contrast_factor): flt_images = convert_image_dtype(images, dtypes.float32) # pylint: disable=protected-access - adjusted = gen_image_ops._adjust_contrastv2(flt_images, - contrast_factor=contrast_factor, - name=name) + adjusted = gen_image_ops._adjust_contrastv2( + flt_images, contrast_factor=contrast_factor, name=name) # pylint: enable=protected-access return convert_image_dtype(adjusted, orig_dtype, saturate=True) +@tf_export('image.adjust_gamma') def adjust_gamma(image, gamma=1, gain=1): """Performs Gamma Correction on the input image. @@ -1026,16 +1060,17 @@ def adjust_gamma(image, gamma=1, gain=1): 'Gamma should be a non-negative real number.') if assert_op: gamma = control_flow_ops.with_dependencies(assert_op, gamma) - + # scale = max(dtype) - min(dtype). - scale = constant_op.constant(image.dtype.limits[1] - image.dtype.limits[0], - dtype=dtypes.float32) + scale = constant_op.constant( + image.dtype.limits[1] - image.dtype.limits[0], dtype=dtypes.float32) # According to the definition of gamma correction. - adjusted_img = (img / scale) ** gamma * scale * gain + adjusted_img = (img / scale)**gamma * scale * gain return adjusted_img +@tf_export('image.convert_image_dtype') def convert_image_dtype(image, dtype, saturate=False, name=None): """Convert `image` to `dtype`, scaling its values if needed. @@ -1114,6 +1149,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None): return math_ops.cast(scaled, dtype, name=name) +@tf_export('image.rgb_to_grayscale') def rgb_to_grayscale(images, name=None): """Converts one or more images from RGB to Grayscale. @@ -1143,6 +1179,7 @@ def rgb_to_grayscale(images, name=None): return convert_image_dtype(gray_float, orig_dtype, name=name) +@tf_export('image.grayscale_to_rgb') def grayscale_to_rgb(images, name=None): """Converts one or more images from Grayscale to RGB. @@ -1159,9 +1196,8 @@ def grayscale_to_rgb(images, name=None): with ops.name_scope(name, 'grayscale_to_rgb', [images]) as name: images = ops.convert_to_tensor(images, name='images') rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0) - shape_list = ( - [array_ops.ones(rank_1, - dtype=dtypes.int32)] + [array_ops.expand_dims(3, 0)]) + shape_list = ([array_ops.ones(rank_1, dtype=dtypes.int32)] + + [array_ops.expand_dims(3, 0)]) multiples = array_ops.concat(shape_list, 0) rgb = array_ops.tile(images, multiples, name=name) rgb.set_shape(images.get_shape()[:-1].concatenate([3])) @@ -1169,6 +1205,7 @@ def grayscale_to_rgb(images, name=None): # pylint: disable=invalid-name +@tf_export('image.random_hue') def random_hue(image, max_delta, seed=None): """Adjust the hue of an RGB image by a random factor. @@ -1201,6 +1238,7 @@ def random_hue(image, max_delta, seed=None): return adjust_hue(image, delta) +@tf_export('image.adjust_hue') def adjust_hue(image, delta, name=None): """Adjust hue of an RGB image. @@ -1234,6 +1272,7 @@ def adjust_hue(image, delta, name=None): return convert_image_dtype(rgb_altered, orig_dtype) +@tf_export('image.random_saturation') def random_saturation(image, lower, upper, seed=None): """Adjust the saturation of an RGB image by a random factor. @@ -1266,6 +1305,7 @@ def random_saturation(image, lower, upper, seed=None): return adjust_saturation(image, saturation_factor) +@tf_export('image.adjust_saturation') def adjust_saturation(image, saturation_factor, name=None): """Adjust saturation of an RGB image. @@ -1297,6 +1337,8 @@ def adjust_saturation(image, saturation_factor, name=None): gen_image_ops.adjust_saturation(flt_image, saturation_factor), orig_dtype) + +@tf_export('image.decode_image') def decode_image(contents, channels=None, name=None): """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, and `decode_png`. @@ -1351,8 +1393,7 @@ def decode_image(contents, channels=None, name=None): gif_channels = 0 if channels is None else channels good_channels = math_ops.logical_and( math_ops.not_equal(gif_channels, 1, name='check_gif_channels'), - math_ops.not_equal(gif_channels, 4, name='check_gif_channels') - ) + math_ops.not_equal(gif_channels, 4, name='check_gif_channels')) channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images' assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_channels]): @@ -1375,8 +1416,8 @@ def decode_image(contents, channels=None, name=None): def _jpeg(): """Decodes a jpeg image.""" jpeg_channels = 0 if channels is None else channels - good_channels = math_ops.not_equal(jpeg_channels, 4, - name='check_jpeg_channels') + good_channels = math_ops.not_equal( + jpeg_channels, 4, name='check_jpeg_channels') channels_msg = ('Channels must be in (None, 0, 1, 3) when decoding JPEG ' 'images') assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) @@ -1389,6 +1430,7 @@ def decode_image(contents, channels=None, name=None): return control_flow_ops.cond(is_jpeg, _jpeg, check_png, name='cond_jpeg') +@tf_export('image.total_variation') def total_variation(images, name=None): """Calculate and return the total variation for one or more images. @@ -1453,15 +1495,21 @@ def total_variation(images, name=None): # Calculate the total variation by taking the absolute value of the # pixel-differences and summing over the appropriate axis. - tot_var = (math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) + - math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis)) + tot_var = ( + math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) + + math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis)) return tot_var -def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None, - seed2=None, min_object_covered=None, - aspect_ratio_range=None, area_range=None, +@tf_export('image.sample_distorted_bounding_box') +def sample_distorted_bounding_box(image_size, + bounding_boxes, + seed=None, + seed2=None, + min_object_covered=0.1, + aspect_ratio_range=None, + area_range=None, max_attempts=None, use_image_if_no_bounding_boxes=None, name=None): @@ -1477,10 +1525,12 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None, The output of this Op is a single bounding box that may be used to crop the original image. The output is returned as 3 tensors: `begin`, `size` and `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the - image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize + image. The latter may be supplied to `tf.image.draw_bounding_boxes` to + visualize what the bounding box looks like. - Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The + Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. + The bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and height of the underlying image. @@ -1508,23 +1558,27 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None, false and no bounding boxes are supplied, an error is raised. Args: - image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`. + image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`, + `int16`, `int32`, `int64`. 1-D, containing `[height, width, channels]`. bounding_boxes: A `Tensor` of type `float32`. 3-D with shape `[batch, N, 4]` describing the N bounding boxes associated with the image. seed: An optional `int`. Defaults to `0`. If either `seed` or `seed2` are set to non-zero, the random number - generator is seeded by the given `seed`. Otherwise, it is seeded by a random + generator is seeded by the given `seed`. Otherwise, it is seeded by a + random seed. seed2: An optional `int`. Defaults to `0`. A second seed to avoid seed collision. min_object_covered: A Tensor of type `float32`. Defaults to `0.1`. The cropped area of the image must contain at least this - fraction of any bounding box supplied. The value of this parameter should be + fraction of any bounding box supplied. The value of this parameter should + be non-negative. In the case of 0, the cropped area does not need to overlap any of the bounding boxes supplied. - aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75, 1.33]`. + aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75, + 1.33]`. The cropped area of the image must have an aspect ratio = width / height within this range. area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. @@ -1532,34 +1586,44 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None, supplied image within in this range. max_attempts: An optional `int`. Defaults to `100`. Number of attempts at generating a cropped region of the image - of the specified constraints. After `max_attempts` failures, return the entire + of the specified constraints. After `max_attempts` failures, return the + entire image. use_image_if_no_bounding_boxes: An optional `bool`. Defaults to `False`. Controls behavior if no bounding boxes supplied. - If true, assume an implicit bounding box covering the whole input. If false, + If true, assume an implicit bounding box covering the whole input. If + false, raise an error. name: A name for the operation (optional). Returns: A tuple of `Tensor` objects (begin, size, bboxes). - begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to + begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing + `[offset_height, offset_width, 0]`. Provide as input to `tf.slice`. - size: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[target_height, target_width, -1]`. Provide as input to + size: A `Tensor`. Has the same type as `image_size`. 1-D, containing + `[target_height, target_width, -1]`. Provide as input to `tf.slice`. - bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing the distorted bounding box. + bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing + the distorted bounding box. Provide as input to `tf.image.draw_bounding_boxes`. """ with ops.name_scope(name, 'sample_distorted_bounding_box'): - return gen_image_ops._sample_distorted_bounding_box_v2(image_size, - bounding_boxes, seed=seed, - seed2=seed2, min_object_covered=min_object_covered, - aspect_ratio_range=aspect_ratio_range, area_range=area_range, - max_attempts=max_attempts, - use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes, - name=name) - - + return gen_image_ops._sample_distorted_bounding_box_v2( # pylint: disable=protected-access + image_size, + bounding_boxes, + seed=seed, + seed2=seed2, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes, + name=name) + + +@tf_export('image.non_max_suppression') def non_max_suppression(boxes, scores, max_output_size, @@ -1604,3 +1668,106 @@ def non_max_suppression(boxes, return gen_image_ops._non_max_suppression_v2(boxes, scores, max_output_size, iou_threshold) # pylint: enable=protected-access + + +_rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115], + [0.587, -0.27455667, -0.52273617], + [0.114, -0.32134392, 0.31119955]] + + +def rgb_to_yiq(images): + """Converts one or more images from RGB to YIQ. + + Outputs a tensor of the same shape as the `images` tensor, containing the YIQ + value of the pixels. + The output is only well defined if the value in images are in [0,1]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. + """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor(_rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + + +_yiq_to_rgb_kernel = [[1, 1, 1], + [0.95598634, -0.27201283, -1.10674021], + [0.6208248, -0.64720424, 1.70423049]] + + +def yiq_to_rgb(images): + """Converts one or more images from YIQ to RGB. + + Outputs a tensor of the same shape as the `images` tensor, containing the RGB + value of the pixels. + The output is only well defined if the Y value in images are in [0,1], + I value are in [-0.5957,0.5957] and Q value are in [-0.5226,0.5226]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. + """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor(_yiq_to_rgb_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + + +_rgb_to_yuv_kernel = [[0.299, -0.14714119, 0.61497538], + [0.587, -0.28886916, -0.51496512], + [0.114, 0.43601035, -0.10001026]] + + +def rgb_to_yuv(images): + """Converts one or more images from RGB to YUV. + + Outputs a tensor of the same shape as the `images` tensor, containing the YUV + value of the pixels. + The output is only well defined if the value in images are in [0,1]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. + """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor(_rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + + +_yuv_to_rgb_kernel = [[1, 1, 1], + [0, -0.394642334, 2.03206185], + [1.13988303, -0.58062185, 0]] + + +def yuv_to_rgb(images): + """Converts one or more images from YUV to RGB. + + Outputs a tensor of the same shape as the `images` tensor, containing the RGB + value of the pixels. + The output is only well defined if the Y value in images are in [0,1], + U and V value are in [-0.5,0.5]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. + """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor(_yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 3a49d41c9ea031126286a4b70861394d6907381f..9834384634261e5d99cac6a4d09b0417b9b2f883 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -85,6 +85,64 @@ class RGBToHSVTest(test_util.TensorFlowTestCase): self.assertAllClose(rgb_tf, rgb_np) +class RGBToYIQTest(test_util.TensorFlowTestCase): + + def testBatch(self): + # Build an arbitrary RGB image + np.random.seed(7) + batch_size = 5 + shape = (batch_size, 2, 7, 3) + + for nptype in [np.float32, np.float64]: + inp = np.random.rand(*shape).astype(nptype) + + # Convert to YIQ and back, as a batch and individually + with self.test_session(use_gpu=True) as sess: + batch0 = constant_op.constant(inp) + batch1 = image_ops.rgb_to_yiq(batch0) + batch2 = image_ops.yiq_to_rgb(batch1) + split0 = array_ops.unstack(batch0) + split1 = list(map(image_ops.rgb_to_yiq, split0)) + split2 = list(map(image_ops.yiq_to_rgb, split1)) + join1 = array_ops.stack(split1) + join2 = array_ops.stack(split2) + batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2]) + + # Verify that processing batch elements together is the same as separate + self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4) + + +class RGBToYUVTest(test_util.TensorFlowTestCase): + + def testBatch(self): + # Build an arbitrary RGB image + np.random.seed(7) + batch_size = 5 + shape = (batch_size, 2, 7, 3) + + for nptype in [np.float32, np.float64]: + inp = np.random.rand(*shape).astype(nptype) + + # Convert to YUV and back, as a batch and individually + with self.test_session(use_gpu=True) as sess: + batch0 = constant_op.constant(inp) + batch1 = image_ops.rgb_to_yuv(batch0) + batch2 = image_ops.yuv_to_rgb(batch1) + split0 = array_ops.unstack(batch0) + split1 = list(map(image_ops.rgb_to_yuv, split0)) + split2 = list(map(image_ops.yuv_to_rgb, split1)) + join1 = array_ops.stack(split1) + join2 = array_ops.stack(split2) + batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2]) + + # Verify that processing batch elements together is the same as separate + self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4) + + class GrayscaleToRGBTest(test_util.TensorFlowTestCase): def _RGBToGrayscale(self, images): @@ -1857,6 +1915,25 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): self.assertAllEqual([3], end.get_shape().as_list()) self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list()) + def testDefaultMinObjectCovered(self): + # By default min_object_covered=0.1 if not provided + with self.test_session(use_gpu=True): + image_size = constant_op.constant( + [40, 50, 1], shape=[3], dtype=dtypes.int32) + bounding_box = constant_op.constant( + [0.0, 0.0, 1.0, 1.0], + shape=[4], + dtype=dtypes.float32,) + begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box( + image_size=image_size, + bounding_boxes=bounding_box, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0)) + + self.assertAllEqual([3], begin.get_shape().as_list()) + self.assertAllEqual([3], end.get_shape().as_list()) + self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list()) + class ResizeImagesTest(test_util.TensorFlowTestCase): @@ -2833,6 +2910,16 @@ class PngTest(test_util.TensorFlowTestCase): class GifTest(test_util.TensorFlowTestCase): + def testOptimizedGifErrorString(self): + filename = "tensorflow/core/lib/gif/testdata/optimized.gif" + + with self.test_session(use_gpu=True) as sess: + gif = io_ops.read_file(filename) + image = image_ops.decode_gif(gif) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, "can't process optimized gif"): + gif, image = sess.run([gif, image]) + def testValid(self): # Read some real GIFs prefix = "tensorflow/core/lib/gif/testdata/" diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 5dc43d65b955613698efccd06f60f1c1b05842d6..c7502d0fda5c38079362d30877a917e3965e6ca0 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -44,8 +44,10 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.tf_export import tf_export +@tf_export("keras.initializers.Initializer") class Initializer(object): """Initializer base class: all initializers inherit from this class. """ @@ -83,6 +85,8 @@ class Initializer(object): return cls(**config) +@tf_export("keras.initializers.Zeros", "initializers.zeros", + "zeros_initializer") class Zeros(Initializer): """Initializer that generates tensors initialized to 0.""" @@ -98,6 +102,7 @@ class Zeros(Initializer): return {"dtype": self.dtype.name} +@tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer") class Ones(Initializer): """Initializer that generates tensors initialized to 1.""" @@ -113,6 +118,8 @@ class Ones(Initializer): return {"dtype": self.dtype.name} +@tf_export("keras.initializers.Constant", "initializers.constant", + "constant_initializer") class Constant(Initializer): """Initializer that generates tensors with constant values. @@ -217,6 +224,8 @@ class Constant(Initializer): return {"value": self.value, "dtype": self.dtype.name} +@tf_export("keras.initializers.RandomUniform", "initializers.random_uniform", + "random_uniform_initializer") class RandomUniform(Initializer): """Initializer that generates tensors with a uniform distribution. @@ -252,6 +261,8 @@ class RandomUniform(Initializer): } +@tf_export("keras.initializers.RandomNormal", "initializers.random_normal", + "random_normal_initializer") class RandomNormal(Initializer): """Initializer that generates tensors with a normal distribution. @@ -287,6 +298,8 @@ class RandomNormal(Initializer): } +@tf_export("keras.initializers.TruncatedNormal", + "initializers.truncated_normal", "truncated_normal_initializer") class TruncatedNormal(Initializer): """Initializer that generates a truncated normal distribution. @@ -327,6 +340,8 @@ class TruncatedNormal(Initializer): } +@tf_export("initializers.uniform_unit_scaling", + "uniform_unit_scaling_initializer") class UniformUnitScaling(Initializer): """Initializer that generates tensors without scaling variance. @@ -385,6 +400,8 @@ class UniformUnitScaling(Initializer): return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name} +@tf_export("keras.initializers.VarianceScaling", + "initializers.variance_scaling", "variance_scaling_initializer") class VarianceScaling(Initializer): """Initializer capable of adapting its scale to the shape of weights tensors. @@ -464,6 +481,8 @@ class VarianceScaling(Initializer): } +@tf_export("keras.initializers.Orthogonal", "initializers.orthogonal", + "orthogonal_initializer") class Orthogonal(Initializer): """Initializer that generates an orthogonal matrix. @@ -523,6 +542,7 @@ class Orthogonal(Initializer): return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} +@tf_export("keras.initializers.Identity", "initializers.identity") class Identity(Initializer): """Initializer that generates the identity matrix. @@ -570,6 +590,7 @@ identity_initializer = Identity # pylint: enable=invalid-name +@tf_export("glorot_uniform_initializer") def glorot_uniform_initializer(seed=None, dtype=dtypes.float32): """The Glorot uniform initializer, also called Xavier uniform initializer. @@ -593,6 +614,7 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32): scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype) +@tf_export("glorot_normal_initializer") def glorot_normal_initializer(seed=None, dtype=dtypes.float32): """The Glorot normal initializer, also called Xavier normal initializer. diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 670bb9a9c29e8450b101b04ce781dc97ceb78398..5e70b3186f382a0c795b1795b2db27bb2058ee41 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -79,6 +79,7 @@ from tensorflow.python.ops import gen_io_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_io_ops import * +from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import @@ -140,6 +141,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, preferred_shard, name=name) +@tf_export("ReaderBase") class ReaderBase(object): """Base class for different Reader types, that produce a record every step. @@ -354,6 +356,7 @@ ops.NotDifferentiable("ReaderRestoreState") ops.NotDifferentiable("ReaderReset") +@tf_export("WholeFileReader") class WholeFileReader(ReaderBase): """A Reader that outputs the entire contents of a file as a value. @@ -381,6 +384,7 @@ class WholeFileReader(ReaderBase): ops.NotDifferentiable("WholeFileReader") +@tf_export("TextLineReader") class TextLineReader(ReaderBase): """A Reader that outputs the lines of a file delimited by newlines. @@ -410,6 +414,7 @@ class TextLineReader(ReaderBase): ops.NotDifferentiable("TextLineReader") +@tf_export("FixedLengthRecordReader") class FixedLengthRecordReader(ReaderBase): """A Reader that outputs fixed-length records from a file. @@ -452,6 +457,7 @@ class FixedLengthRecordReader(ReaderBase): ops.NotDifferentiable("FixedLengthRecordReader") +@tf_export("TFRecordReader") class TFRecordReader(ReaderBase): """A Reader that outputs the records from a TFRecords file. @@ -482,6 +488,7 @@ class TFRecordReader(ReaderBase): ops.NotDifferentiable("TFRecordReader") +@tf_export("LMDBReader") class LMDBReader(ReaderBase): """A Reader that outputs the records from a LMDB file. @@ -506,6 +513,7 @@ class LMDBReader(ReaderBase): ops.NotDifferentiable("LMDBReader") +@tf_export("IdentityReader") class IdentityReader(ReaderBase): """A Reader that outputs the queued work as both the key and value. diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index 13a32c83d99363e687f7e2365a91c8e453c81c7e..3cbbf3412a2a1bd974354a5819d410b4074ab47d 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -277,20 +277,28 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): # https://j-towns.github.io/papers/svd-derivative.pdf a = op.inputs[0] a_shape = a.get_shape().with_rank_at_least(2) + grad_s_mat = array_ops.matrix_diag(grad_s) - if op.get_attr("compute_uv"): - # TODO(rmlarsen): Make this work with complex types. - if a.dtype.is_complex: - raise NotImplementedError( - "SVD gradient is not implemented for complex types and " - "compute_uv=True.") - grad_u_shape = grad_u.get_shape().with_rank_at_least(2) - grad_v_shape = grad_v.get_shape().with_rank_at_least(2) - m = a_shape[-2].merge_with(grad_u_shape[-2]) - n = a_shape[-1].merge_with(grad_v_shape[-2]) - batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with( - grad_v_shape[:-2]) - a_shape = batch_shape.concatenate([m, n]) + if not op.get_attr("compute_uv"): + s, u, v = linalg_ops.svd(a, compute_uv=True) + grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True)) + grad_a.set_shape(a_shape) + return grad_a + + full_matrices = op.get_attr("full_matrices") + + # TODO(rmlarsen): Make this work with complex types. + if a.dtype.is_complex: + raise NotImplementedError( + "SVD gradient is not implemented for complex types and " + "compute_uv=True.") + grad_u_shape = grad_u.get_shape().with_rank_at_least(2) + grad_v_shape = grad_v.get_shape().with_rank_at_least(2) + m = a_shape[-2].merge_with(grad_u_shape[-2]) + n = a_shape[-1].merge_with(grad_v_shape[-2]) + batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with( + grad_v_shape[:-2]) + a_shape = batch_shape.concatenate([m, n]) m = a_shape[-2].value n = a_shape[-1].value @@ -300,12 +308,9 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): "SVD gradient has not been implemented for input with unknown " "inner matrix shape.") - if not op.get_attr("compute_uv"): - s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True) - else: - s = op.outputs[0] - u = op.outputs[1] - v = op.outputs[2] + s = op.outputs[0] + u = op.outputs[1] + v = op.outputs[2] use_adjoint = False if m > n: @@ -317,19 +322,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): grad_u, grad_v = grad_v, grad_u with ops.control_dependencies([grad_s, grad_u, grad_v]): - grad_s_mat = array_ops.matrix_diag(grad_s) - if not op.get_attr("compute_uv"): - if use_adjoint: - grad_a = math_ops.matmul( - v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True) - else: - grad_a = math_ops.matmul(u, - math_ops.matmul( - grad_s_mat, v[..., :, :m], adjoint_b=True)) - grad_a.set_shape(a_shape) - return grad_a - - if op.get_attr("full_matrices") and abs(m - n) > 1: + if full_matrices and abs(m - n) > 1: raise NotImplementedError( "svd gradient is not implemented for abs(m - n) > 1 " "when full_matrices is True") @@ -371,7 +364,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): gv1t_v1 = math_ops.matmul(gv1t, v1) term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True) - if op.get_attr("full_matrices"): + if full_matrices: v2 = v[..., :, m:n] grad_v2 = grad_v[..., :, m:n] diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index be9beee633bb7c900b1618c2922b6eff5bf65df0..9803eed6aefe072cbe0841dff2de3f640a440dd5 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.ops.gen_linalg_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export # Names below are lower_case. # pylint: disable=invalid-name @@ -77,6 +78,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind): return gen_linalg_ops.cholesky(gramian) +@tf_export('cholesky_solve', 'linalg.cholesky_solve') def cholesky_solve(chol, rhs, name=None): """Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations. @@ -119,6 +121,7 @@ def cholesky_solve(chol, rhs, name=None): return x +@tf_export('eye', 'linalg.eye') def eye(num_rows, num_columns=None, batch_shape=None, @@ -188,6 +191,7 @@ def eye(num_rows, return array_ops.matrix_set_diag(zero_matrix, diag_ones) +@tf_export('matrix_solve_ls', 'linalg.lstsq') def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): r"""Solves one or more linear least-squares problems. @@ -324,6 +328,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): # pylint: enable=protected-access +@tf_export('self_adjoint_eig', 'linalg.eigh') def self_adjoint_eig(tensor, name=None): """Computes the eigen decomposition of a batch of self-adjoint matrices. @@ -346,6 +351,7 @@ def self_adjoint_eig(tensor, name=None): return e, v +@tf_export('self_adjoint_eigvals', 'linalg.eigvalsh') def self_adjoint_eigvals(tensor, name=None): """Computes the eigenvalues of one or more self-adjoint matrices. @@ -368,6 +374,7 @@ def self_adjoint_eigvals(tensor, name=None): return e +@tf_export('svd', 'linalg.svd') def svd(tensor, full_matrices=False, compute_uv=True, name=None): r"""Computes the singular value decompositions of one or more matrices. @@ -439,6 +446,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None): # pylint: disable=redefined-builtin +@tf_export('norm', 'linalg.norm') @deprecation.deprecated_args( None, 'keep_dims is deprecated, use keepdims instead', 'keep_dims') def norm(tensor, diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py index 6b31c0063983d19ce281183ec57a230c5909e5b1..bba59ebcef9c7caf1a53d724767999ae7ac079e5 100644 --- a/tensorflow/python/ops/list_ops.py +++ b/tensorflow/python/ops/list_ops.py @@ -19,7 +19,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_list_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -28,28 +30,30 @@ from tensorflow.python.ops.gen_list_ops import * @ops.RegisterGradient("TensorListPushBack") -def _PushBackGradient(op, dresult): +def _PushBackGrad(op, dresult): return gen_list_ops.tensor_list_pop_back( dresult, element_dtype=op.get_attr("element_dtype")) @ops.RegisterGradient("TensorListPopBack") -def _PopBackGradient(unused_op, dlist, delement): +def _PopBackGrad(op, dlist, delement): if dlist is None: dlist = gen_list_ops.empty_tensor_list( element_dtype=delement.dtype, - element_shape=-1) + element_shape=gen_list_ops.tensor_list_element_shape( + op.outputs[0], shape_type=dtypes.int32)) return gen_list_ops.tensor_list_push_back(dlist, delement) @ops.RegisterGradient("TensorListStack") -def _TensorListStack(unused_op, dtensor): +def _TensorListStackGrad(unused_op, dtensor): return gen_list_ops.tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]) @ops.RegisterGradient("TensorListFromTensor") -def _TensorListFromTensor(op, dlist): +def _TensorListFromTensorGrad(op, dlist): + """Gradient for TensorListFromTensor.""" if op.inputs[0].shape[0] is not None: num_elements = op.inputs[0].shape[0] else: @@ -57,7 +61,34 @@ def _TensorListFromTensor(op, dlist): if dlist is None: dlist = gen_list_ops.empty_tensor_list( element_dtype=op.inputs[0].dtype, - element_shape=-1) + element_shape=gen_list_ops.tensor_list_element_shape( + op.outputs[0], shape_type=dtypes.int32)) return gen_list_ops.tensor_list_stack( dlist, element_dtype=op.inputs[0].dtype, num_elements=num_elements) + + +@ops.RegisterGradient("TensorListGetItem") +def _TensorListGetItemGrad(op, ditem): + """Gradient for TensorListGetItem.""" + list_size = gen_list_ops.tensor_list_length(op.inputs[0]) + list_grad = gen_list_ops.tensor_list_set_item( + gen_list_ops.tensor_list_reserve( + gen_list_ops.tensor_list_element_shape(op.inputs[0], + shape_type=dtypes.int32), + list_size, element_dtype=ditem.dtype), + index=op.inputs[1], + item=ditem) + index_grad = None + return list_grad, index_grad + + +@ops.RegisterGradient("TensorListSetItem") +def _TensorListSetItemGrad(op, dlist): + _, index, item = op.inputs + list_grad = gen_list_ops.tensor_list_set_item( + dlist, index=index, item=array_ops.zeros_like(item)) + index_grad = None + element_grad = gen_list_ops.tensor_list_get_item( + dlist, index, element_dtype=item.dtype) + return list_grad, index_grad, element_grad diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py index 51ab2aec2298a9072c90c226992f122a804ec02e..eadbc1b7c3b6e66aa76c9afd860b2274ac1976ae 100644 --- a/tensorflow/python/ops/logging_ops.py +++ b/tensorflow/python/ops/logging_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops.gen_logging_ops import * # pylint: enable=wildcard-import from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.tf_export import tf_export # The python wrapper for Assert is in control_flow_ops, as the Assert # call relies on certain conditionals for its dependencies. Use @@ -35,6 +36,7 @@ from tensorflow.python.util.deprecation import deprecated # Assert and Print are special symbols in python, so we must # use an upper-case version of them. +@tf_export("Print") def Print(input_, data, message=None, first_n=None, summarize=None, name=None): """Prints a list of tensors. diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 333e36873af31a7a89d59d02af87d86227446bd0..f539a7bb68da57e31746bc80fb25339a03a4fafe 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -40,8 +40,10 @@ from tensorflow.python.ops.gen_lookup_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.tf_export import tf_export +@tf_export("initialize_all_tables") @deprecated(None, "Use `tf.tables_initializer` instead.") def initialize_all_tables(name="init_all_tables"): """Returns an Op that initializes all tables of the default graph. @@ -56,6 +58,7 @@ def initialize_all_tables(name="init_all_tables"): return tables_initializer(name) +@tf_export("tables_initializer") def tables_initializer(name="init_all_tables"): """Returns an Op that initializes all tables of the default graph. diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index cfdfa09757654aeb10426e1361176baca38d7b6a..b8e8207bb24ad64d9e07a4585501a10741f5c9ab 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -172,6 +172,7 @@ from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export # Aliases for some automatically-generated names. linspace = gen_math_ops.lin_space @@ -190,6 +191,7 @@ def _set_doc(doc): # pylint: disable=redefined-builtin +@tf_export("argmax") @deprecation.deprecated_args(None, "Use the `axis` argument instead", "dimension") @_set_doc( @@ -209,6 +211,7 @@ def argmax(input, return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type) +@tf_export("argmin") @deprecation.deprecated_args(None, "Use the `axis` argument instead", "dimension") @_set_doc( @@ -233,6 +236,7 @@ def argmin(input, # pylint: disable=anomalous-backslash-in-string,protected-access # pylint: disable=g-docstring-has-escape +@tf_export("abs") def abs(x, name=None): r"""Computes the absolute value of a tensor. @@ -307,6 +311,7 @@ class DivideDelegateWithName(object): return _div_python2(self.x, y, self.name) +@tf_export("divide") def divide(x, y, name=None): """Computes Python style division of `x` by `y`.""" @@ -318,6 +323,7 @@ def divide(x, y, name=None): return x / y +@tf_export("multiply") def multiply(x, y, name=None): return gen_math_ops._mul(x, y, name) @@ -337,6 +343,7 @@ _mul.__doc__ = ( gen_math_ops._mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__)) +@tf_export("subtract") def subtract(x, y, name=None): return gen_math_ops._sub(x, y, name) @@ -357,6 +364,7 @@ _sub.__doc__ = ( # pylint: disable=g-docstring-has-escape +@tf_export("negative") def negative(x, name=None): """Computes numerical negative value element-wise. @@ -405,6 +413,7 @@ def _neg(x, name=None): # pylint: enable=g-docstring-has-escape +@tf_export("sign") def sign(x, name=None): """Returns an element-wise indication of the sign of a number. @@ -435,6 +444,7 @@ def sign(x, name=None): return gen_math_ops.sign(x, name=name) +@tf_export("square") def square(x, name=None): r"""Computes square of x element-wise. @@ -457,6 +467,7 @@ def square(x, name=None): return gen_math_ops.square(x, name=name) +@tf_export("sqrt") def sqrt(x, name=None): r"""Computes square root of x element-wise. @@ -479,6 +490,7 @@ def sqrt(x, name=None): return gen_math_ops.sqrt(x, name=name) +@tf_export("erf") def erf(x, name=None): """Computes the Gauss error function of `x` element-wise. @@ -499,6 +511,7 @@ def erf(x, name=None): return gen_math_ops.erf(x, name=name) +@tf_export("scalar_mul") def scalar_mul(scalar, x): """Multiplies a scalar times a `Tensor` or `IndexedSlices` object. @@ -528,6 +541,7 @@ def scalar_mul(scalar, x): raise ValueError("Only scalar multiply works, got shape %s" % shape) +@tf_export("pow") def pow(x, y, name=None): r"""Computes the power of one value to another. @@ -555,6 +569,7 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin,redefined-outer-name +@tf_export("complex") def complex(real, imag, name=None): r"""Converts two real numbers to a complex number. @@ -596,6 +611,7 @@ def complex(real, imag, name=None): return gen_math_ops._complex(real, imag, Tout=Tout, name=name) +@tf_export("real") def real(input, name=None): r"""Returns the real part of a complex (or real) tensor. @@ -626,6 +642,7 @@ def real(input, name=None): return input +@tf_export("imag") def imag(input, name=None): r"""Returns the imaginary part of a complex (or real) tensor. @@ -655,6 +672,7 @@ def imag(input, name=None): return array_ops.zeros_like(input) +@tf_export("angle") def angle(input, name=None): r"""Returns the element-wise argument of a complex (or real) tensor. @@ -693,6 +711,7 @@ def angle(input, name=None): # pylint: enable=redefined-outer-name,redefined-builtin +@tf_export("round") def round(x, name=None): """Rounds the values of a tensor to the nearest integer, element-wise. @@ -719,6 +738,7 @@ def round(x, name=None): return gen_math_ops.round(x, name=name) +@tf_export("cast") def cast(x, dtype, name=None): """Casts a tensor to a new type. @@ -759,6 +779,7 @@ def cast(x, dtype, name=None): return gen_math_ops.cast(x, base_type, name=name) +@tf_export("saturate_cast") def saturate_cast(value, dtype, name=None): """Performs a safe saturating cast of `value` to `dtype`. @@ -792,6 +813,7 @@ def saturate_cast(value, dtype, name=None): return cast(value, dtype, name=name) +@tf_export("to_float") def to_float(x, name="ToFloat"): """Casts a tensor to type `float32`. @@ -808,6 +830,7 @@ def to_float(x, name="ToFloat"): return cast(x, dtypes.float32, name=name) +@tf_export("to_double") def to_double(x, name="ToDouble"): """Casts a tensor to type `float64`. @@ -824,6 +847,7 @@ def to_double(x, name="ToDouble"): return cast(x, dtypes.float64, name=name) +@tf_export("to_int32") def to_int32(x, name="ToInt32"): """Casts a tensor to type `int32`. @@ -840,6 +864,7 @@ def to_int32(x, name="ToInt32"): return cast(x, dtypes.int32, name=name) +@tf_export("to_int64") def to_int64(x, name="ToInt64"): """Casts a tensor to type `int64`. @@ -856,6 +881,7 @@ def to_int64(x, name="ToInt64"): return cast(x, dtypes.int64, name=name) +@tf_export("to_bfloat16") def to_bfloat16(x, name="ToBFloat16"): """Casts a tensor to type `bfloat16`. @@ -1029,6 +1055,7 @@ def _div_python2(x, y, name=None): return gen_math_ops._floor_div(x, y, name=name) +@tf_export("truediv") def truediv(x, y, name=None): """Divides x / y elementwise (using Python 3 division operator semantics). @@ -1060,6 +1087,7 @@ def truediv(x, y, name=None): return _truediv_python3(x, y, name) +@tf_export("div") def div(x, y, name=None): """Divides x / y elementwise (using Python 2 division operator semantics). @@ -1087,6 +1115,7 @@ mod = gen_math_ops._floor_mod # TODO(aselle): Deprecate this once all internal functionality uses # tf.truncatediv +@tf_export("floordiv") def floordiv(x, y, name=None): """Divides `x / y` elementwise, rounding toward the most negative integer. @@ -1157,6 +1186,7 @@ _OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod") _OverrideBinaryOperatorHelper(pow, "pow") +@tf_export("logical_xor") def logical_xor(x, y, name="LogicalXor"): """x ^ y = (x | y) & ~(x & y).""" # TODO(alemi) Make this a cwise op if people end up relying on it. @@ -1176,6 +1206,7 @@ ops.Tensor._override_operator("__gt__", gen_math_ops.greater) ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal) +@tf_export("range") def range(start, limit=None, delta=1, dtype=None, name="range"): """Creates a sequence of numbers. @@ -1281,6 +1312,7 @@ def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output): return output +@tf_export("reduce_sum") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_sum(input_tensor, @@ -1341,6 +1373,7 @@ def reduce_sum(input_tensor, name=name)) +@tf_export("count_nonzero") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def count_nonzero(input_tensor, @@ -1407,6 +1440,7 @@ def count_nonzero(input_tensor, dtype=dtype) +@tf_export("reduce_mean") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_mean(input_tensor, @@ -1478,6 +1512,7 @@ def reduce_mean(input_tensor, name=name)) +@tf_export("reduce_prod") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_prod(input_tensor, @@ -1527,6 +1562,7 @@ def reduce_prod(input_tensor, name=name)) +@tf_export("reduce_min") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_min(input_tensor, @@ -1575,6 +1611,7 @@ def reduce_min(input_tensor, name=name)) +@tf_export("reduce_max") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_max(input_tensor, @@ -1623,6 +1660,7 @@ def reduce_max(input_tensor, name=name)) +@tf_export("reduce_all") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_all(input_tensor, @@ -1680,6 +1718,7 @@ def reduce_all(input_tensor, name=name)) +@tf_export("reduce_any") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_any(input_tensor, @@ -1737,6 +1776,7 @@ def reduce_any(input_tensor, name=name)) +@tf_export("reduce_logsumexp") @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def reduce_logsumexp(input_tensor, @@ -1810,6 +1850,7 @@ def reduce_logsumexp(input_tensor, return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result) +@tf_export("trace", "linalg.trace") def trace(x, name=None): """Compute the trace of a tensor `x`. @@ -1851,6 +1892,7 @@ def trace(x, name=None): return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name) +@tf_export("matmul") def matmul(a, b, transpose_a=False, @@ -2103,6 +2145,7 @@ def _as_indexed_slices_list(inputs, optimize=True): return casted_outputs +@tf_export("add_n") def add_n(inputs, name=None): """Adds all input tensors element-wise. @@ -2132,6 +2175,7 @@ def add_n(inputs, name=None): return gen_math_ops._add_n(inputs, name=name) +@tf_export("accumulate_n") def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): """Returns the element-wise sum of a list of tensors. @@ -2216,6 +2260,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): ref, var_name=var.op.name, name=name) +@tf_export("nn.sigmoid", "sigmoid") def sigmoid(x, name=None): """Computes sigmoid of `x` element-wise. @@ -2238,6 +2283,7 @@ def sigmoid(x, name=None): return gen_math_ops._sigmoid(x, name=name) +@tf_export("log_sigmoid") def log_sigmoid(x, name=None): """Computes log sigmoid of `x` element-wise. @@ -2256,6 +2302,7 @@ def log_sigmoid(x, name=None): return gen_math_ops._neg(gen_nn_ops.softplus(-x), name=name) +@tf_export("nn.tanh", "tanh") def tanh(x, name=None): """Computes hyperbolic tangent of `x` element-wise. @@ -2276,6 +2323,7 @@ def tanh(x, name=None): return gen_math_ops._tanh(x, name=name) +@tf_export("bincount") def bincount(arr, weights=None, minlength=None, @@ -2322,6 +2370,7 @@ def bincount(arr, return gen_math_ops.bincount(arr, output_size, weights) +@tf_export("cumsum") def cumsum(x, axis=0, exclusive=False, reverse=False, name=None): """Compute the cumulative sum of the tensor `x` along `axis`. @@ -2373,6 +2422,7 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None): x, axis, exclusive=exclusive, reverse=reverse, name=name) +@tf_export("cumprod") def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): """Compute the cumulative product of the tensor `x` along `axis`. @@ -2424,6 +2474,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): x, axis, exclusive=exclusive, reverse=reverse, name=name) +@tf_export("conj") def conj(x, name=None): r"""Returns the complex conjugate of a complex number. @@ -2502,6 +2553,7 @@ def reduced_shape(input_shape, axes): ]) # [1, 1] +@tf_export("sparse_segment_sum") def sparse_segment_sum(data, indices, segment_ids, name=None, num_segments=None): r"""Computes the sum along sparse segments of a tensor. @@ -2576,6 +2628,7 @@ def sparse_segment_sum(data, indices, segment_ids, name=None, name=name) +@tf_export("sparse_segment_mean") def sparse_segment_mean(data, indices, segment_ids, name=None, num_segments=None): r"""Computes the mean along sparse segments of a tensor. @@ -2619,6 +2672,7 @@ def sparse_segment_mean(data, indices, segment_ids, name=None, name=name) +@tf_export("sparse_segment_sqrt_n") def sparse_segment_sqrt_n(data, indices, segment_ids, name=None, num_segments=None): r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N). @@ -2655,6 +2709,7 @@ def sparse_segment_sqrt_n(data, indices, segment_ids, name=None, name=name) +@tf_export("tensordot", "linalg.tensordot") def tensordot(a, b, axes, name=None): r"""Tensor contraction of a and b along specified axes. @@ -2772,10 +2827,14 @@ def tensordot(a, b, axes, name=None): """Generates two sets of contraction axes for the two tensor arguments.""" a_shape = a.get_shape() if isinstance(axes, compat.integral_types): - if axes < 1: - raise ValueError("'axes' must be at least 1.") + if axes < 0: + raise ValueError("'axes' must be at least 0.") if a_shape.ndims is not None: - return range(a_shape.ndims - axes, a_shape.ndims), range(axes) + if axes > a_shape.ndims: + raise ValueError("'axes' must not be larger than the number of " + "dimensions of tensor %s." % a) + return (list(xrange(a_shape.ndims - axes, a_shape.ndims)), + list(xrange(axes))) else: rank = array_ops.rank(a) return (range(rank - axes, rank, dtype=dtypes.int32), diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 25e1613a651cf3bc144e121b61f1edd64a16596e..7776ff08c4f55c43947010f313d8167596b15db7 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.tf_export import tf_export def metric_variable(shape, dtype, validate_shape=True, name=None): @@ -99,27 +100,29 @@ def _remove_squeezable_dimensions(predictions, labels, weights): # Use dynamic rank. weights_rank_tensor = array_ops.rank(weights) rank_diff = weights_rank_tensor - array_ops.rank(predictions) + def _maybe_expand_weights(): return control_flow_ops.cond( math_ops.equal(rank_diff, -1), - lambda: array_ops.expand_dims(weights, [-1]), - lambda: weights) + lambda: array_ops.expand_dims(weights, [-1]), lambda: weights) + # Don't attempt squeeze if it will fail based on static check. if ((weights_rank is not None) and (not weights_shape.dims[-1].is_compatible_with(1))): maybe_squeeze_weights = lambda: weights else: maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1]) + def _maybe_adjust_weights(): return control_flow_ops.cond( - math_ops.equal(rank_diff, 1), - maybe_squeeze_weights, + math_ops.equal(rank_diff, 1), maybe_squeeze_weights, _maybe_expand_weights) + # If weights are scalar, do nothing. Otherwise, try to add or remove a # dimension to match predictions. weights = control_flow_ops.cond( - math_ops.equal(weights_rank_tensor, 0), - lambda: weights, _maybe_adjust_weights) + math_ops.equal(weights_rank_tensor, 0), lambda: weights, + _maybe_adjust_weights) return predictions, labels, weights @@ -164,14 +167,14 @@ def _maybe_expand_labels(labels, predictions): if predictions_rank == labels_rank + 1: return array_ops.expand_dims(labels, -1, name=scope) raise ValueError( - 'Unexpected labels shape %s for predictions shape %s.' % ( - labels.get_shape(), predictions.get_shape())) + 'Unexpected labels shape %s for predictions shape %s.' % + (labels.get_shape(), predictions.get_shape())) # Otherwise, use dynamic shape. return control_flow_ops.cond( - math_ops.equal(array_ops.rank(predictions), array_ops.rank(labels) + 1), - lambda: array_ops.expand_dims(labels, -1, name=scope), - lambda: labels) + math_ops.equal(array_ops.rank(predictions), + array_ops.rank(labels) + 1), + lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels) def _safe_div(numerator, denominator, name): @@ -262,8 +265,12 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None): return total_cm, update_op -def mean(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +@tf_export('metrics.mean') +def mean(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the (weighted) mean of the given values. The `mean` function creates two local variables, `total` and `count` @@ -337,8 +344,13 @@ def mean(values, weights=None, metrics_collections=None, return mean_t, update_op -def accuracy(labels, predictions, weights=None, metrics_collections=None, - updates_collections=None, name=None): +@tf_export('metrics.accuracy') +def accuracy(labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Calculates how often `predictions` matches `labels`. The `accuracy` function creates two local variables, `total` and @@ -392,12 +404,15 @@ def accuracy(labels, predictions, weights=None, metrics_collections=None, if labels.dtype != predictions.dtype: predictions = math_ops.cast(predictions, labels.dtype) is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) - return mean(is_correct, weights, metrics_collections, - updates_collections, name or 'accuracy') + return mean(is_correct, weights, metrics_collections, updates_collections, + name or 'accuracy') -def _confusion_matrix_at_thresholds( - labels, predictions, thresholds, weights=None, includes=None): +def _confusion_matrix_at_thresholds(labels, + predictions, + thresholds, + weights=None, + includes=None): """Computes true_positives, false_negatives, true_negatives, false_positives. This function creates up to four local variables, `true_positives`, @@ -495,8 +510,8 @@ def _confusion_matrix_at_thresholds( if weights is not None: weights = weights_broadcast_ops.broadcast_weights( math_ops.to_float(weights), predictions) - weights_tiled = array_ops.tile(array_ops.reshape( - weights, [1, -1]), [num_thresholds, 1]) + weights_tiled = array_ops.tile( + array_ops.reshape(weights, [1, -1]), [num_thresholds, 1]) thresh_tiled.get_shape().assert_is_compatible_with( weights_tiled.get_shape()) else: @@ -512,8 +527,9 @@ def _confusion_matrix_at_thresholds( math_ops.logical_and(label_is_pos, pred_is_pos)) if weights_tiled is not None: is_true_positive *= weights_tiled - update_ops['tp'] = state_ops.assign_add( - true_p, math_ops.reduce_sum(is_true_positive, 1)) + update_ops['tp'] = state_ops.assign_add(true_p, + math_ops.reduce_sum( + is_true_positive, 1)) values['tp'] = true_p if 'fn' in includes: @@ -523,8 +539,9 @@ def _confusion_matrix_at_thresholds( math_ops.logical_and(label_is_pos, pred_is_neg)) if weights_tiled is not None: is_false_negative *= weights_tiled - update_ops['fn'] = state_ops.assign_add( - false_n, math_ops.reduce_sum(is_false_negative, 1)) + update_ops['fn'] = state_ops.assign_add(false_n, + math_ops.reduce_sum( + is_false_negative, 1)) values['fn'] = false_n if 'tn' in includes: @@ -534,8 +551,9 @@ def _confusion_matrix_at_thresholds( math_ops.logical_and(label_is_neg, pred_is_neg)) if weights_tiled is not None: is_true_negative *= weights_tiled - update_ops['tn'] = state_ops.assign_add( - true_n, math_ops.reduce_sum(is_true_negative, 1)) + update_ops['tn'] = state_ops.assign_add(true_n, + math_ops.reduce_sum( + is_true_negative, 1)) values['tn'] = true_n if 'fp' in includes: @@ -545,16 +563,24 @@ def _confusion_matrix_at_thresholds( math_ops.logical_and(label_is_neg, pred_is_pos)) if weights_tiled is not None: is_false_positive *= weights_tiled - update_ops['fp'] = state_ops.assign_add( - false_p, math_ops.reduce_sum(is_false_positive, 1)) + update_ops['fp'] = state_ops.assign_add(false_p, + math_ops.reduce_sum( + is_false_positive, 1)) values['fp'] = false_p return values, update_ops -def auc(labels, predictions, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, - curve='ROC', name=None, summation_method='trapezoidal'): +@tf_export('metrics.auc') +def auc(labels, + predictions, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + curve='ROC', + name=None, + summation_method='trapezoidal'): """Computes the approximate AUC via a Riemann sum. The `auc` function creates four local variables, `true_positives`, @@ -622,14 +648,14 @@ def auc(labels, predictions, weights=None, num_thresholds=200, raise RuntimeError('tf.metrics.auc is not supported when eager execution ' 'is enabled.') - with variable_scope.variable_scope( - name, 'auc', (labels, predictions, weights)): + with variable_scope.variable_scope(name, 'auc', + (labels, predictions, weights)): if curve != 'ROC' and curve != 'PR': - raise ValueError('curve must be either ROC or PR, %s unknown' % - (curve)) + raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) kepsilon = 1e-7 # to account for floating point imprecisions - thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) - for i in range(num_thresholds-2)] + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + ] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] values, update_ops = _confusion_matrix_at_thresholds( @@ -637,6 +663,7 @@ def auc(labels, predictions, weights=None, num_thresholds=200, # Add epsilons to avoid dividing by 0. epsilon = 1.0e-6 + def compute_auc(tp, fn, tn, fp, name): """Computes the roc-auc or pr-auc based on confusion counts.""" rec = math_ops.div(tp + epsilon, tp + fn + epsilon) @@ -667,11 +694,10 @@ def auc(labels, predictions, weights=None, num_thresholds=200, raise ValueError('Invalid summation_method: %s' % summation_method) # sum up the areas of all the trapeziums - auc_value = compute_auc( - values['tp'], values['fn'], values['tn'], values['fp'], 'value') - update_op = compute_auc( - update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], - 'update_op') + auc_value = compute_auc(values['tp'], values['fn'], values['tn'], + values['fp'], 'value') + update_op = compute_auc(update_ops['tp'], update_ops['fn'], + update_ops['tn'], update_ops['fp'], 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, auc_value) @@ -682,7 +708,10 @@ def auc(labels, predictions, weights=None, num_thresholds=200, return auc_value, update_op -def mean_absolute_error(labels, predictions, weights=None, +@tf_export('metrics.mean_absolute_error') +def mean_absolute_error(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -740,7 +769,11 @@ def mean_absolute_error(labels, predictions, weights=None, updates_collections, name or 'mean_absolute_error') -def mean_cosine_distance(labels, predictions, dim, weights=None, +@tf_export('metrics.mean_cosine_distance') +def mean_cosine_distance(labels, + predictions, + dim, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -796,10 +829,8 @@ def mean_cosine_distance(labels, predictions, dim, weights=None, radial_diffs, reduction_indices=[ dim, ], keepdims=True) - mean_distance, update_op = mean(radial_diffs, weights, - None, - None, - name or 'mean_cosine_distance') + mean_distance, update_op = mean(radial_diffs, weights, None, None, name or + 'mean_cosine_distance') mean_distance = math_ops.subtract(1.0, mean_distance) update_op = math_ops.subtract(1.0, update_op) @@ -812,6 +843,7 @@ def mean_cosine_distance(labels, predictions, dim, weights=None, return mean_distance, update_op +@tf_export('metrics.mean_per_class_accuracy') def mean_per_class_accuracy(labels, predictions, num_classes, @@ -824,8 +856,8 @@ def mean_per_class_accuracy(labels, Calculates the accuracy for each class, then takes the mean of that. For estimation of the metric over a stream of data, the function creates an - `update_op` operation that updates these variables and returns the - `mean_accuracy`. + `update_op` operation that updates the accuracy of each class and returns + them. If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. @@ -836,8 +868,8 @@ def mean_per_class_accuracy(labels, shape is [batch size] and type `int32` or `int64`. The tensor will be flattened if its rank > 1. num_classes: The possible number of labels the prediction task can - have. This value must be provided, since a confusion matrix of - dimension = [num_classes, num_classes] will be allocated. + have. This value must be provided, since two variables with shape = + [num_classes] will be allocated. weights: Optional `Tensor` whose rank is either 0, or the same rank as `labels`, and must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). @@ -850,7 +882,7 @@ def mean_per_class_accuracy(labels, Returns: mean_accuracy: A `Tensor` representing the mean per class accuracy. - update_op: An operation that increments the confusion matrix. + update_op: An operation that updates the accuracy tensor. Raises: ValueError: If `predictions` and `labels` have mismatched shapes, or if @@ -865,27 +897,43 @@ def mean_per_class_accuracy(labels, with variable_scope.variable_scope(name, 'mean_accuracy', (predictions, labels, weights)): + labels = math_ops.to_int64(labels) + + # Flatten the input if its rank > 1. + if labels.get_shape().ndims > 1: + labels = array_ops.reshape(labels, [-1]) + + if predictions.get_shape().ndims > 1: + predictions = array_ops.reshape(predictions, [-1]) + # Check if shape is compatible. predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - total_cm, update_op = _streaming_confusion_matrix( - labels, predictions, num_classes, weights=weights) + total = metric_variable([num_classes], dtypes.float32, name='total') + count = metric_variable([num_classes], dtypes.float32, name='count') - def compute_mean_accuracy(name): - """Compute the mean per class accuracy via the confusion matrix.""" - per_row_sum = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) - cm_diag = math_ops.to_float(array_ops.diag_part(total_cm)) - denominator = per_row_sum + ones = array_ops.ones([array_ops.size(labels)], dtypes.float32) - # If the value of the denominator is 0, set it to 1 to avoid - # zero division. - denominator = array_ops.where( - math_ops.greater(denominator, 0), denominator, - array_ops.ones_like(denominator)) - accuracies = math_ops.div(cm_diag, denominator) - return math_ops.reduce_mean(accuracies, name=name) + if labels.dtype != predictions.dtype: + predictions = math_ops.cast(predictions, labels.dtype) + is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) + + if weights is not None: + if weights.get_shape().ndims > 1: + weights = array_ops.reshape(weights, [-1]) + weights = math_ops.to_float(weights) - mean_accuracy_v = compute_mean_accuracy('mean_accuracy') + is_correct = is_correct * weights + ones = ones * weights + + update_total_op = state_ops.scatter_add(total, labels, ones) + update_count_op = state_ops.scatter_add(count, labels, is_correct) + + per_class_accuracy = _safe_div(count, total, None) + + mean_accuracy_v = math_ops.reduce_mean( + per_class_accuracy, name='mean_accuracy') + update_op = _safe_div(update_count_op, update_total_op, name='update_op') if metrics_collections: ops.add_to_collections(metrics_collections, mean_accuracy_v) @@ -896,6 +944,7 @@ def mean_per_class_accuracy(labels, return mean_accuracy_v, update_op +@tf_export('metrics.mean_iou') def mean_iou(labels, predictions, num_classes, @@ -951,13 +1000,14 @@ def mean_iou(labels, raise RuntimeError('tf.metrics.mean_iou is not supported when ' 'eager execution is enabled.') - with variable_scope.variable_scope( - name, 'mean_iou', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'mean_iou', + (predictions, labels, weights)): # Check if shape is compatible. predictions.get_shape().assert_is_compatible_with(labels.get_shape()) total_cm, update_op = _streaming_confusion_matrix(labels, predictions, num_classes, weights) + def compute_mean_iou(name): """Compute the mean intersection-over-union via the confusion matrix.""" sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0)) @@ -968,22 +1018,21 @@ def mean_iou(labels, # The mean is only computed over classes that appear in the # label or prediction tensor. If the denominator is 0, we need to # ignore the class. - num_valid_entries = math_ops.reduce_sum(math_ops.cast( - math_ops.not_equal(denominator, 0), dtype=dtypes.float32)) + num_valid_entries = math_ops.reduce_sum( + math_ops.cast( + math_ops.not_equal(denominator, 0), dtype=dtypes.float32)) # If the value of the denominator is 0, set it to 1 to avoid # zero division. denominator = array_ops.where( - math_ops.greater(denominator, 0), - denominator, + math_ops.greater(denominator, 0), denominator, array_ops.ones_like(denominator)) iou = math_ops.div(cm_diag, denominator) # If the number of valid entries is 0 (no classes) we return 0. result = array_ops.where( math_ops.greater(num_valid_entries, 0), - math_ops.reduce_sum(iou, name=name) / num_valid_entries, - 0) + math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0) return result mean_iou_v = compute_mean_iou('mean_iou') @@ -997,7 +1046,11 @@ def mean_iou(labels, return mean_iou_v, update_op -def mean_relative_error(labels, predictions, normalizer, weights=None, +@tf_export('metrics.mean_relative_error') +def mean_relative_error(labels, + predictions, + normalizer, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1056,14 +1109,16 @@ def mean_relative_error(labels, predictions, normalizer, weights=None, predictions, normalizer) predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) relative_errors = array_ops.where( - math_ops.equal(normalizer, 0.0), - array_ops.zeros_like(labels), + math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels), math_ops.div(math_ops.abs(labels - predictions), normalizer)) return mean(relative_errors, weights, metrics_collections, updates_collections, name or 'mean_relative_error') -def mean_squared_error(labels, predictions, weights=None, +@tf_export('metrics.mean_squared_error') +def mean_squared_error(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1117,12 +1172,16 @@ def mean_squared_error(labels, predictions, weights=None, predictions, labels, weights = _remove_squeezable_dimensions( predictions=predictions, labels=labels, weights=weights) squared_error = math_ops.square(labels - predictions) - return mean(squared_error, weights, metrics_collections, - updates_collections, name or 'mean_squared_error') + return mean(squared_error, weights, metrics_collections, updates_collections, + name or 'mean_squared_error') -def mean_tensor(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +@tf_export('metrics.mean_tensor') +def mean_tensor(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the element-wise (weighted) mean of the given tensors. In contrast to the `mean` function which returns a scalar with the @@ -1189,9 +1248,8 @@ def mean_tensor(values, weights=None, metrics_collections=None, update_count_op = state_ops.assign_add(count, num_values) def compute_mean(total, count, name): - non_zero_count = math_ops.maximum(count, - array_ops.ones_like(count), - name=name) + non_zero_count = math_ops.maximum( + count, array_ops.ones_like(count), name=name) return math_ops.truediv(total, non_zero_count, name=name) mean_t = compute_mean(total, count, 'value') @@ -1206,7 +1264,10 @@ def mean_tensor(values, weights=None, metrics_collections=None, return mean_t, update_op -def percentage_below(values, threshold, weights=None, +@tf_export('metrics.percentage_below') +def percentage_below(values, + threshold, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1253,14 +1314,13 @@ def percentage_below(values, threshold, weights=None, 'eager execution is enabled.') is_below_threshold = math_ops.to_float(math_ops.less(values, threshold)) - return mean(is_below_threshold, - weights, - metrics_collections, - updates_collections, - name or 'percentage_below_threshold') + return mean(is_below_threshold, weights, metrics_collections, + updates_collections, name or 'percentage_below_threshold') -def _count_condition(values, weights=None, metrics_collections=None, +def _count_condition(values, + weights=None, + metrics_collections=None, updates_collections=None): """Sums the weights of cases where the given values are True. @@ -1290,8 +1350,8 @@ def _count_condition(values, weights=None, metrics_collections=None, values = math_ops.to_float(values) if weights is not None: - with ops.control_dependencies(( - check_ops.assert_rank_in(weights, (0, array_ops.rank(values))),)): + with ops.control_dependencies((check_ops.assert_rank_in( + weights, (0, array_ops.rank(values))),)): weights = math_ops.to_float(weights) values = math_ops.multiply(values, weights) @@ -1307,7 +1367,10 @@ def _count_condition(values, weights=None, metrics_collections=None, return value_tensor, update_op -def false_negatives(labels, predictions, weights=None, +@tf_export('metrics.false_negatives') +def false_negatives(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1343,20 +1406,24 @@ def false_negatives(labels, predictions, weights=None, raise RuntimeError('tf.metrics.false_negatives is not supported when ' 'eager execution is enabled.') - with variable_scope.variable_scope( - name, 'false_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_negatives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_false_negative = math_ops.logical_and(math_ops.equal(labels, True), - math_ops.equal(predictions, False)) + is_false_negative = math_ops.logical_and( + math_ops.equal(labels, True), math_ops.equal(predictions, False)) return _count_condition(is_false_negative, weights, metrics_collections, updates_collections) -def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None, +@tf_export('metrics.false_negatives_at_thresholds') +def false_negatives_at_thresholds(labels, + predictions, + thresholds, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1409,7 +1476,10 @@ def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None, return values['fn'], update_ops['fn'] -def false_positives(labels, predictions, weights=None, +@tf_export('metrics.false_positives') +def false_positives(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1446,20 +1516,24 @@ def false_positives(labels, predictions, weights=None, raise RuntimeError('tf.metrics.false_positives is not supported when ' 'eager execution is enabled.') - with variable_scope.variable_scope( - name, 'false_positives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_positives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_false_positive = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, True)) + is_false_positive = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, True)) return _count_condition(is_false_positive, weights, metrics_collections, updates_collections) -def false_positives_at_thresholds(labels, predictions, thresholds, weights=None, +@tf_export('metrics.false_positives_at_thresholds') +def false_positives_at_thresholds(labels, + predictions, + thresholds, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1512,7 +1586,10 @@ def false_positives_at_thresholds(labels, predictions, thresholds, weights=None, return values['fp'], update_ops['fp'] -def true_negatives(labels, predictions, weights=None, +@tf_export('metrics.true_negatives') +def true_negatives(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1549,20 +1626,24 @@ def true_negatives(labels, predictions, weights=None, raise RuntimeError('tf.metrics.true_negatives is not ' 'supported when eager execution is enabled.') - with variable_scope.variable_scope( - name, 'true_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, False)) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, updates_collections) -def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None, +@tf_export('metrics.true_negatives_at_thresholds') +def true_negatives_at_thresholds(labels, + predictions, + thresholds, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1615,7 +1696,10 @@ def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None, return values['tn'], update_ops['tn'] -def true_positives(labels, predictions, weights=None, +@tf_export('metrics.true_positives') +def true_positives(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1652,20 +1736,24 @@ def true_positives(labels, predictions, weights=None, raise RuntimeError('tf.metrics.true_positives is not ' 'supported when eager execution is enabled.') - with variable_scope.variable_scope( - name, 'true_positives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_positives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_positive = math_ops.logical_and(math_ops.equal(labels, True), - math_ops.equal(predictions, True)) + is_true_positive = math_ops.logical_and( + math_ops.equal(labels, True), math_ops.equal(predictions, True)) return _count_condition(is_true_positive, weights, metrics_collections, updates_collections) -def true_positives_at_thresholds(labels, predictions, thresholds, weights=None, +@tf_export('metrics.true_positives_at_thresholds') +def true_positives_at_thresholds(labels, + predictions, + thresholds, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1718,8 +1806,12 @@ def true_positives_at_thresholds(labels, predictions, thresholds, weights=None, return values['tp'], update_ops['tp'] -def precision(labels, predictions, weights=None, - metrics_collections=None, updates_collections=None, +@tf_export('metrics.precision') +def precision(labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the precision of the predictions with respect to the labels. @@ -1768,8 +1860,8 @@ def precision(labels, predictions, weights=None, raise RuntimeError('tf.metrics.precision is not ' 'supported when eager execution is enabled.') - with variable_scope.variable_scope( - name, 'precision', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'precision', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), @@ -1777,22 +1869,27 @@ def precision(labels, predictions, weights=None, weights=weights) true_p, true_positives_update_op = true_positives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) false_p, false_positives_update_op = false_positives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) def compute_precision(tp, fp, name): return array_ops.where( - math_ops.greater(tp + fp, 0), - math_ops.div(tp, tp + fp), - 0, - name) + math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name) p = compute_precision(true_p, false_p, 'value') - update_op = compute_precision( - true_positives_update_op, false_positives_update_op, 'update_op') + update_op = compute_precision(true_positives_update_op, + false_positives_update_op, 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, p) @@ -1803,10 +1900,14 @@ def precision(labels, predictions, weights=None, return p, update_op -def precision_at_thresholds(labels, predictions, thresholds, +@tf_export('metrics.precision_at_thresholds') +def precision_at_thresholds(labels, + predictions, + thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): + updates_collections=None, + name=None): """Computes precision values for different `thresholds` on `predictions`. The `precision_at_thresholds` function creates four local variables, @@ -1862,12 +1963,13 @@ def precision_at_thresholds(labels, predictions, thresholds, # Avoid division by zero. epsilon = 1e-7 + def compute_precision(tp, fp, name): return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name) prec = compute_precision(values['tp'], values['fp'], 'value') - update_op = compute_precision( - update_ops['tp'], update_ops['fp'], 'update_op') + update_op = compute_precision(update_ops['tp'], update_ops['fp'], + 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, prec) @@ -1878,8 +1980,12 @@ def precision_at_thresholds(labels, predictions, thresholds, return prec, update_op -def recall(labels, predictions, weights=None, - metrics_collections=None, updates_collections=None, +@tf_export('metrics.recall') +def recall(labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall of the predictions with respect to the labels. @@ -1926,30 +2032,36 @@ def recall(labels, predictions, weights=None, raise RuntimeError('tf.metrics.recall is not supported is not ' 'supported when eager execution is enabled.') - with variable_scope.variable_scope( - name, 'recall', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'recall', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) true_p, true_positives_update_op = true_positives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) false_n, false_negatives_update_op = false_negatives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) def compute_recall(true_p, false_n, name): return array_ops.where( math_ops.greater(true_p + false_n, 0), - math_ops.div(true_p, true_p + false_n), - 0, - name) + math_ops.div(true_p, true_p + false_n), 0, name) rec = compute_recall(true_p, false_n, 'value') - update_op = compute_recall( - true_positives_update_op, false_negatives_update_op, 'update_op') + update_op = compute_recall(true_positives_update_op, + false_negatives_update_op, 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, rec) @@ -1983,8 +2095,8 @@ def _select_class_id(ids, selected_id): """ ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids) if isinstance(ids, sparse_tensor.SparseTensor): - return sparse_ops.sparse_retain( - ids, math_ops.equal(ids.values, selected_id)) + return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values, + selected_id)) # TODO(ptucker): Make this more efficient, maybe add a sparse version of # tf.equal and tf.reduce_any? @@ -1992,12 +2104,13 @@ def _select_class_id(ids, selected_id): # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. ids_shape = array_ops.shape(ids, out_type=dtypes.int64) ids_last_dim = array_ops.size(ids_shape) - 1 - filled_selected_id_shape = math_ops.reduced_shape( - ids_shape, array_ops.reshape(ids_last_dim, [1])) + filled_selected_id_shape = math_ops.reduced_shape(ids_shape, + array_ops.reshape( + ids_last_dim, [1])) # Intersect `ids` with the selected ID. - filled_selected_id = array_ops.fill( - filled_selected_id_shape, math_ops.to_int64(selected_id)) + filled_selected_id = array_ops.fill(filled_selected_id_shape, + math_ops.to_int64(selected_id)) result = sets.set_intersection(filled_selected_id, ids) return sparse_tensor.SparseTensor( indices=result.indices, values=result.values, dense_shape=ids_shape) @@ -2057,15 +2170,15 @@ def _sparse_true_positive_at_k(labels, Returns: A [D1, ... DN] `Tensor` of true positive counts. """ - with ops.name_scope( - name, 'true_positives', (predictions_idx, labels, weights)): - labels, predictions_idx = _maybe_select_class_id( - labels, predictions_idx, class_id) + with ops.name_scope(name, 'true_positives', + (predictions_idx, labels, weights)): + labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, + class_id) tp = sets.set_size(sets.set_intersection(predictions_idx, labels)) tp = math_ops.to_double(tp) if weights is not None: - with ops.control_dependencies(( - weights_broadcast_ops.assert_broadcastable(weights, tp),)): + with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( + weights, tp),)): weights = math_ops.to_double(weights) tp = math_ops.multiply(tp, weights) return tp @@ -2109,11 +2222,12 @@ def _streaming_sparse_true_positive_at_k(labels, Raises: ValueError: If `weights` is not `None` and has an incompatible shape. """ - with ops.name_scope( - name, _at_k_name('true_positive', k, class_id=class_id), - (predictions_idx, labels, weights)) as scope: + with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id), + (predictions_idx, labels, weights)) as scope: tp = _sparse_true_positive_at_k( - predictions_idx=predictions_idx, labels=labels, class_id=class_id, + predictions_idx=predictions_idx, + labels=labels, + class_id=class_id, weights=weights) batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp)) @@ -2150,18 +2264,16 @@ def _sparse_false_negative_at_k(labels, Returns: A [D1, ... DN] `Tensor` of false negative counts. """ - with ops.name_scope( - None, 'false_negatives', (predictions_idx, labels, weights)): - labels, predictions_idx = _maybe_select_class_id(labels, - predictions_idx, + with ops.name_scope(None, 'false_negatives', + (predictions_idx, labels, weights)): + labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, class_id) - fn = sets.set_size(sets.set_difference(predictions_idx, - labels, - aminusb=False)) + fn = sets.set_size( + sets.set_difference(predictions_idx, labels, aminusb=False)) fn = math_ops.to_double(fn) if weights is not None: - with ops.control_dependencies(( - weights_broadcast_ops.assert_broadcastable(weights, fn),)): + with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( + weights, fn),)): weights = math_ops.to_double(weights) fn = math_ops.multiply(fn, weights) return fn @@ -2205,11 +2317,12 @@ def _streaming_sparse_false_negative_at_k(labels, Raises: ValueError: If `weights` is not `None` and has an incompatible shape. """ - with ops.name_scope( - name, _at_k_name('false_negative', k, class_id=class_id), - (predictions_idx, labels, weights)) as scope: + with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id), + (predictions_idx, labels, weights)) as scope: fn = _sparse_false_negative_at_k( - predictions_idx=predictions_idx, labels=labels, class_id=class_id, + predictions_idx=predictions_idx, + labels=labels, + class_id=class_id, weights=weights) batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn)) @@ -2217,6 +2330,7 @@ def _streaming_sparse_false_negative_at_k(labels, return var, state_ops.assign_add(var, batch_total_fn, name='update') +@tf_export('metrics.recall_at_k') def recall_at_k(labels, predictions, k, @@ -2295,9 +2409,8 @@ def recall_at_k(labels, raise RuntimeError('tf.metrics.recall_at_k is not ' 'supported when eager execution is enabled.') - with ops.name_scope( - name, _at_k_name('recall', k, class_id=class_id), - (predictions, labels, weights)) as scope: + with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), + (predictions, labels, weights)) as scope: _, top_k_idx = nn.top_k(predictions, k) return recall_at_top_k( labels=labels, @@ -2310,6 +2423,7 @@ def recall_at_k(labels, name=scope) +@tf_export('metrics.recall_at_top_k') def recall_at_top_k(labels, predictions_idx, k=None, @@ -2363,16 +2477,21 @@ def recall_at_top_k(labels, `predictions`, or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with ops.name_scope(name, - _at_k_name('recall', k, class_id=class_id), + with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), (predictions_idx, labels, weights)) as scope: labels = _maybe_expand_labels(labels, predictions_idx) top_k_idx = math_ops.to_int64(predictions_idx) tp, tp_update = _streaming_sparse_true_positive_at_k( - predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, + predictions_idx=top_k_idx, + labels=labels, + k=k, + class_id=class_id, weights=weights) fn, fn_update = _streaming_sparse_false_negative_at_k( - predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, + predictions_idx=top_k_idx, + labels=labels, + k=k, + class_id=class_id, weights=weights) metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) @@ -2385,9 +2504,14 @@ def recall_at_top_k(labels, return metric, update -def recall_at_thresholds(labels, predictions, thresholds, - weights=None, metrics_collections=None, - updates_collections=None, name=None): +@tf_export('metrics.recall_at_thresholds') +def recall_at_thresholds(labels, + predictions, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various recall values for different `thresholds` on `predictions`. The `recall_at_thresholds` function creates four local variables, @@ -2441,6 +2565,7 @@ def recall_at_thresholds(labels, predictions, thresholds, # Avoid division by zero. epsilon = 1e-7 + def compute_recall(tp, fn, name): return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) @@ -2456,7 +2581,10 @@ def recall_at_thresholds(labels, predictions, thresholds, return rec, update_op -def root_mean_squared_error(labels, predictions, weights=None, +@tf_export('metrics.root_mean_squared_error') +def root_mean_squared_error(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2509,9 +2637,9 @@ def root_mean_squared_error(labels, predictions, weights=None, predictions, labels, weights = _remove_squeezable_dimensions( predictions=predictions, labels=labels, weights=weights) - mse, update_mse_op = mean_squared_error( - labels, predictions, weights, None, None, - name or 'root_mean_squared_error') + mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, + None, name or + 'root_mean_squared_error') rmse = math_ops.sqrt(mse) update_rmse_op = math_ops.sqrt(update_mse_op) @@ -2525,9 +2653,15 @@ def root_mean_squared_error(labels, predictions, weights=None, return rmse, update_rmse_op -def sensitivity_at_specificity( - labels, predictions, specificity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): +@tf_export('metrics.sensitivity_at_specificity') +def sensitivity_at_specificity(labels, + predictions, + specificity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the specificity at a given sensitivity. The `sensitivity_at_specificity` function creates four local @@ -2588,8 +2722,9 @@ def sensitivity_at_specificity( with variable_scope.variable_scope(name, 'sensitivity_at_specificity', (predictions, labels, weights)): kepsilon = 1e-7 # to account for floating point imprecisions - thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) - for i in range(num_thresholds-2)] + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + ] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] values, update_ops = _confusion_matrix_at_thresholds( @@ -2601,8 +2736,7 @@ def sensitivity_at_specificity( tf_index = math_ops.cast(tf_index, dtypes.int32) # Now, we have the implicit threshold, so compute the sensitivity: - return math_ops.div(tp[tf_index], - tp[tf_index] + fn[tf_index] + kepsilon, + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon, name) sensitivity = compute_sensitivity_at_specificity( @@ -2641,8 +2775,8 @@ def _expand_and_tile(tensor, multiple, dim=0, name=None): """ if multiple < 1: raise ValueError('Invalid multiple %s, must be > 0.' % multiple) - with ops.name_scope( - name, 'expand_and_tile', (tensor, multiple, dim)) as scope: + with ops.name_scope(name, 'expand_and_tile', + (tensor, multiple, dim)) as scope: # Sparse. tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor) if isinstance(tensor, sparse_tensor.SparseTensor): @@ -2742,8 +2876,8 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx): Raises: ValueError: if the last dimension of predictions_idx is not set. """ - with ops.name_scope( - None, 'average_precision', (predictions_idx, labels)) as scope: + with ops.name_scope(None, 'average_precision', + (predictions_idx, labels)) as scope: predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx') if predictions_idx.get_shape().ndims == 0: raise ValueError('The rank of predictions_idx must be at least 1.') @@ -2780,10 +2914,12 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx): retrieved_per_k = math_ops.cumsum( array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') precision_per_k = math_ops.div( - math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k), + math_ops.to_double(tp_per_k), + math_ops.to_double(retrieved_per_k), name='precision_per_k') relevant_precision_per_k = math_ops.multiply( - precision_per_k, math_ops.to_double(relevant_per_k), + precision_per_k, + math_ops.to_double(relevant_per_k), name='relevant_precision_per_k') # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. @@ -2887,6 +3023,7 @@ def _streaming_sparse_average_precision_at_top_k(labels, return mean_average_precision, update +@tf_export('metrics.sparse_average_precision_at_k') @deprecated(None, 'Use average_precision_at_k instead') def sparse_average_precision_at_k(labels, predictions, @@ -2906,6 +3043,7 @@ def sparse_average_precision_at_k(labels, name=name) +@tf_export('metrics.average_precision_at_k') def average_precision_at_k(labels, predictions, k, @@ -2971,9 +3109,8 @@ def average_precision_at_k(labels, if k < 1: raise ValueError('Invalid k=%s.' % k) - with ops.name_scope( - name, _at_k_name('average_precision', k), - (predictions, labels, weights)) as scope: + with ops.name_scope(name, _at_k_name('average_precision', k), + (predictions, labels, weights)) as scope: # Calculate top k indices to produce [D1, ... DN, k] tensor. _, predictions_idx = nn.top_k(predictions, k) return _streaming_sparse_average_precision_at_top_k( @@ -3014,17 +3151,16 @@ def _sparse_false_positive_at_k(labels, Returns: A [D1, ... DN] `Tensor` of false positive counts. """ - with ops.name_scope( - None, 'false_positives', (predictions_idx, labels, weights)): - labels, predictions_idx = _maybe_select_class_id(labels, - predictions_idx, + with ops.name_scope(None, 'false_positives', + (predictions_idx, labels, weights)): + labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, class_id) - fp = sets.set_size(sets.set_difference( - predictions_idx, labels, aminusb=True)) + fp = sets.set_size( + sets.set_difference(predictions_idx, labels, aminusb=True)) fp = math_ops.to_double(fp) if weights is not None: - with ops.control_dependencies(( - weights_broadcast_ops.assert_broadcastable(weights, fp),)): + with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( + weights, fp),)): weights = math_ops.to_double(weights) fp = math_ops.multiply(fp, weights) return fp @@ -3068,11 +3204,12 @@ def _streaming_sparse_false_positive_at_k(labels, Raises: ValueError: If `weights` is not `None` and has an incompatible shape. """ - with ops.name_scope( - name, _at_k_name('false_positive', k, class_id=class_id), - (predictions_idx, labels, weights)) as scope: + with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id), + (predictions_idx, labels, weights)) as scope: fp = _sparse_false_positive_at_k( - predictions_idx=predictions_idx, labels=labels, class_id=class_id, + predictions_idx=predictions_idx, + labels=labels, + class_id=class_id, weights=weights) batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp)) @@ -3080,6 +3217,7 @@ def _streaming_sparse_false_positive_at_k(labels, return var, state_ops.assign_add(var, batch_total_fp, name='update') +@tf_export('metrics.precision_at_top_k') def precision_at_top_k(labels, predictions_idx, k=None, @@ -3143,10 +3281,16 @@ def precision_at_top_k(labels, labels = _maybe_expand_labels(labels, predictions_idx) top_k_idx = math_ops.to_int64(predictions_idx) tp, tp_update = _streaming_sparse_true_positive_at_k( - predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, + predictions_idx=top_k_idx, + labels=labels, + k=k, + class_id=class_id, weights=weights) fp, fp_update = _streaming_sparse_false_positive_at_k( - predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, + predictions_idx=top_k_idx, + labels=labels, + k=k, + class_id=class_id, weights=weights) metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) @@ -3159,6 +3303,7 @@ def precision_at_top_k(labels, return metric, update +@tf_export('metrics.sparse_precision_at_k') @deprecated(None, 'Use precision_at_k instead') def sparse_precision_at_k(labels, predictions, @@ -3180,6 +3325,7 @@ def sparse_precision_at_k(labels, name=name) +@tf_export('metrics.precision_at_k') def precision_at_k(labels, predictions, k, @@ -3273,9 +3419,15 @@ def precision_at_k(labels, name=scope) -def specificity_at_sensitivity( - labels, predictions, sensitivity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): +@tf_export('metrics.specificity_at_sensitivity') +def specificity_at_sensitivity(labels, + predictions, + sensitivity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the specificity at a given sensitivity. The `specificity_at_sensitivity` function creates four local @@ -3336,8 +3488,9 @@ def specificity_at_sensitivity( with variable_scope.variable_scope(name, 'specificity_at_sensitivity', (predictions, labels, weights)): kepsilon = 1e-7 # to account for floating point imprecisions - thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) - for i in range(num_thresholds-2)] + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + ] thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon] values, update_ops = _confusion_matrix_at_thresholds( @@ -3369,8 +3522,7 @@ def specificity_at_sensitivity( tf_index = math_ops.cast(tf_index, dtypes.int32) # Now, we have the implicit threshold, so compute the specificity: - return math_ops.div(tn[tf_index], - tn[tf_index] + fp[tf_index] + kepsilon, + return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon, name) specificity = compute_specificity_at_sensitivity( diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index fd96f7b8fcf423e2381f84b50b0532e46ce2fe6e..55fcd176d62009b9c29afb763dc20daf78cdb5d9 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -35,8 +35,10 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_argument_lookup +from tensorflow.python.util.tf_export import tf_export +@tf_export("nn.log_poisson_loss") def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None): """Computes log Poisson loss given `log_input`. @@ -101,6 +103,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None): return result +@tf_export("nn.sigmoid_cross_entropy_with_logits") def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name _sentinel=None, labels=None, @@ -180,6 +183,7 @@ def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name name=name) +@tf_export("nn.weighted_cross_entropy_with_logits") def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None): """Computes a weighted cross entropy. @@ -192,7 +196,13 @@ def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None): targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits)) - The argument `pos_weight` is used as a multiplier for the positive targets: + A value `pos_weights > 1` decreases the false negative count, hence increasing + the recall. + Conversely setting `pos_weights < 1` decreases the false positive count and + increases the precision. + This can be seen from the fact that `pos_weight` is introduced as a + multiplicative coefficient for the positive targets term + in the loss expression: targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits)) @@ -251,6 +261,7 @@ def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None): name=name) +@tf_export("nn.relu_layer") def relu_layer(x, weights, biases, name=None): """Computes Relu(x * weight + biases). @@ -297,6 +308,7 @@ def _swish_grad(features, grad): shape_func=_swish_shape, func_name="swish", noinline=True) +@tf_export("nn.swish") def swish(features): # pylint: disable=g-doc-args """Computes the Swish activation function: `x * sigmoid(x)`. @@ -316,6 +328,7 @@ def swish(features): return features * math_ops.sigmoid(features) +@tf_export("nn.l2_normalize") @deprecated_args(None, "dim is deprecated, use axis instead", "dim") def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): """Normalizes along dimension `axis` using an L2 norm. @@ -347,6 +360,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): return math_ops.multiply(x, x_inv_norm, name=name) +@tf_export("nn.zero_fraction") def zero_fraction(value, name=None): """Returns the fraction of zeros in `value`. @@ -374,6 +388,7 @@ def zero_fraction(value, name=None): # pylint: disable=redefined-builtin +@tf_export("nn.depthwise_conv2d") def depthwise_conv2d(input, filter, strides, @@ -450,6 +465,7 @@ def depthwise_conv2d(input, # pylint: disable=redefined-builtin,line-too-long +@tf_export("nn.separable_conv2d") def separable_conv2d(input, depthwise_filter, pointwise_filter, @@ -550,6 +566,7 @@ def separable_conv2d(input, # pylint: enable=redefined-builtin,line-too-long +@tf_export("nn.sufficient_statistics") def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None): """Calculate the sufficient statistics for the mean and variance of `x`. @@ -599,6 +616,7 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None): return counts, m_ss, v_ss, shift +@tf_export("nn.normalize_moments") def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): """Calculate the mean and variance of based on the sufficient statistics. @@ -630,9 +648,13 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): return (mean, variance) -def moments(x, axes, - shift=None, # pylint: disable=unused-argument - name=None, keep_dims=False): +@tf_export("nn.moments") +def moments( + x, + axes, + shift=None, # pylint: disable=unused-argument + name=None, + keep_dims=False): """Calculate the mean and variance of `x`. The mean and variance are calculated by aggregating the contents of `x` @@ -676,12 +698,13 @@ def moments(x, axes, mean = array_ops.squeeze(mean, axes) variance = array_ops.squeeze(variance, axes) if x.dtype == dtypes.float16: - return (math_ops.cast(mean, dtypes.float16), math_ops.cast( - variance, dtypes.float16)) + return (math_ops.cast(mean, dtypes.float16), + math_ops.cast(variance, dtypes.float16)) else: return (mean, variance) +@tf_export("nn.weighted_moments") def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False): """Returns the frequency-weighted mean and variance of `x`. @@ -753,6 +776,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False): return weighted_mean, weighted_variance +@tf_export("nn.batch_normalization") def batch_normalization(x, mean, variance, @@ -806,10 +830,11 @@ def batch_normalization(x, inv = math_ops.rsqrt(variance + variance_epsilon) if scale is not None: inv *= scale - return x * inv + (offset - mean * inv - if offset is not None else -mean * inv) + return x * inv + ( + offset - mean * inv if offset is not None else -mean * inv) +@tf_export("nn.fused_batch_norm") def fused_batch_norm( x, scale, @@ -882,6 +907,7 @@ def fused_batch_norm( return y, batch_mean, batch_var +@tf_export("nn.batch_norm_with_global_normalization") def batch_norm_with_global_normalization(t, m, v, @@ -943,7 +969,8 @@ def _compute_sampled_logits(weights, subtract_log_q=True, remove_accidental_hits=False, partition_strategy="mod", - name=None): + name=None, + seed=None): """Helper function for nce_loss and sampled_softmax_loss functions. Computes sampled output training logits and labels suitable for implementing @@ -981,6 +1008,8 @@ def _compute_sampled_logits(weights, if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. name: A name for the operation (optional). + seed: random seed for candidate sampling. Default to None, which doesn't set + the op-level random seed for candidate sampling. Returns: out_logits: `Tensor` object with shape `[batch_size, num_true + num_sampled]`, for passing to either @@ -1010,7 +1039,8 @@ def _compute_sampled_logits(weights, num_true=num_true, num_sampled=num_sampled, unique=True, - range_max=num_classes) + range_max=num_classes, + seed=seed) # NOTE: pylint cannot tell that 'sampled_values' is a sequence # pylint: disable=unpacking-non-sequence sampled, true_expected_count, sampled_expected_count = ( @@ -1109,6 +1139,7 @@ def _compute_sampled_logits(weights, return out_logits, out_labels +@tf_export("nn.nce_loss") def nce_loss(weights, biases, labels, @@ -1217,6 +1248,7 @@ def nce_loss(weights, return _sum_rows(sampled_losses) +@tf_export("nn.sampled_softmax_loss") def sampled_softmax_loss(weights, biases, labels, @@ -1227,7 +1259,8 @@ def sampled_softmax_loss(weights, sampled_values=None, remove_accidental_hits=True, partition_strategy="mod", - name="sampled_softmax_loss"): + name="sampled_softmax_loss", + seed=None): """Computes and returns the sampled softmax training loss. This is a faster way to train a softmax classifier over a huge number of @@ -1288,6 +1321,8 @@ def sampled_softmax_loss(weights, if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. name: A name for the operation (optional). + seed: random seed for candidate sampling. Default to None, which doesn't set + the op-level random seed for candidate sampling. Returns: A `batch_size` 1-D tensor of per-example sampled softmax losses. @@ -1305,7 +1340,8 @@ def sampled_softmax_loss(weights, subtract_log_q=True, remove_accidental_hits=remove_accidental_hits, partition_strategy=partition_strategy, - name=name) + name=name, + seed=seed) sampled_losses = nn_ops.softmax_cross_entropy_with_logits( labels=labels, logits=logits) # sampled_losses is a [batch_size] tensor. diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 865e459e900c0bbfd9b08fbc62725ac6f6a4bcf6..32b14f86b567ce26334c1594e9ac6f00afd5b9d1 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -39,6 +39,7 @@ from tensorflow.python.ops.gen_nn_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export # Aliases for some automatically-generated names. @@ -190,6 +191,7 @@ class _NonAtrousConvolution(object): name=self.name) +@tf_export("nn.with_space_to_batch") def with_space_to_batch( input, # pylint: disable=redefined-builtin dilation_rate, @@ -633,6 +635,7 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate): return strides, dilation_rate +@tf_export("nn.convolution") def convolution(input, filter, # pylint: disable=redefined-builtin padding, strides=None, dilation_rate=None, name=None, data_format=None): @@ -848,6 +851,7 @@ class Convolution(object): return self.conv_op(inp, filter) +@tf_export("nn.pool") def pool(input, # pylint: disable=redefined-builtin window_shape, pooling_type, @@ -1015,6 +1019,7 @@ def pool(input, # pylint: disable=redefined-builtin filter_shape=window_shape) +@tf_export("nn.atrous_conv2d") def atrous_conv2d(value, filters, rate, padding, name=None): """Atrous convolution (a.k.a. convolution with holes or dilated convolution). @@ -1150,6 +1155,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None): name=name) +@tf_export("nn.conv2d_transpose") def conv2d_transpose(value, filter, # pylint: disable=redefined-builtin output_shape, @@ -1225,6 +1231,7 @@ def conv2d_transpose(value, name=name) +@tf_export("nn.atrous_conv2d_transpose") def atrous_conv2d_transpose(value, filters, output_shape, @@ -1371,6 +1378,7 @@ def atrous_conv2d_transpose(value, block_size=rate) +@tf_export("nn.conv3d_transpose") def conv3d_transpose(value, filter, # pylint: disable=redefined-builtin output_shape, @@ -1444,6 +1452,7 @@ def conv3d_transpose(value, # pylint: disable=protected-access +@tf_export("nn.bias_add") def bias_add(value, bias, data_format=None, name=None): """Adds `bias` to `value`. @@ -1498,6 +1507,7 @@ def bias_add_v1(value, bias, name=None): return gen_nn_ops._bias_add_v1(value, bias, name=name) +@tf_export("nn.crelu") def crelu(features, name=None, axis=-1): """Computes Concatenated ReLU. @@ -1521,6 +1531,7 @@ def crelu(features, name=None, axis=-1): return gen_nn_ops.relu(c) +@tf_export("nn.relu6") def relu6(features, name=None): """Computes Rectified Linear 6: `min(max(features, 0), 6)`. Source: [Convolutional Deep Belief Networks on CIFAR-10. A. Krizhevsky](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf) @@ -1538,6 +1549,7 @@ def relu6(features, name=None): return gen_nn_ops._relu6(features, name=name) +@tf_export("nn.leaky_relu") def leaky_relu(features, alpha=0.2, name=None): """Compute the Leaky ReLU activation function. @@ -1546,7 +1558,8 @@ def leaky_relu(features, alpha=0.2, name=None): http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf Args: - features: A `Tensor` representing preactivation values. + features: A `Tensor` representing preactivation values. Must be one of + the following types: `float16`, `float32`, `float64`, `int32`, `int64`. alpha: Slope of the activation function at x < 0. name: A name for the operation (optional). @@ -1555,7 +1568,9 @@ def leaky_relu(features, alpha=0.2, name=None): """ with ops.name_scope(name, "LeakyRelu", [features, alpha]): features = ops.convert_to_tensor(features, name="features") - alpha = ops.convert_to_tensor(alpha, name="alpha") + if features.dtype.is_integer: + features = math_ops.to_float(features) + alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") return math_ops.maximum(alpha * features, features) @@ -1661,6 +1676,7 @@ def _softmax(logits, compute_op, dim=-1, name=None): return output +@tf_export("nn.softmax") @deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") def softmax(logits, axis=None, name=None, dim=None): """Computes softmax activations. @@ -1690,6 +1706,7 @@ def softmax(logits, axis=None, name=None, dim=None): return _softmax(logits, gen_nn_ops._softmax, axis, name) +@tf_export("nn.log_softmax") @deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") def log_softmax(logits, axis=None, name=None, dim=None): """Computes log softmax activations. @@ -1728,6 +1745,7 @@ def _ensure_xent_args(name, sentinel, labels, logits): raise ValueError("Both labels and logits must be provided.") +@tf_export("nn.softmax_cross_entropy_with_logits_v2") def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=invalid-name labels=None, logits=None, dim=-1, name=None): @@ -1842,6 +1860,7 @@ See tf.nn.softmax_cross_entropy_with_logits_v2. """ +@tf_export("nn.softmax_cross_entropy_with_logits") @deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION) def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name labels=None, logits=None, @@ -1898,6 +1917,7 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid labels=labels, logits=logits, dim=dim, name=name) +@tf_export("nn.sparse_softmax_cross_entropy_with_logits") def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name labels=None, logits=None, name=None): @@ -1996,6 +2016,7 @@ def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable= return cost +@tf_export("nn.avg_pool") def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """Performs the average pooling on the input. @@ -2028,6 +2049,7 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None): name=name) +@tf_export("nn.max_pool") def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """Performs the max pooling on the input. @@ -2099,6 +2121,7 @@ def _calc_bias_add_flops(graph, node): return ops.OpStats("flops", input_count) +@tf_export("nn.xw_plus_b") def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name """Computes matmul(x, weights) + biases. @@ -2145,6 +2168,7 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name return bias_add_v1(mm, biases, name=name) +@tf_export("nn.dropout") def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name """Computes dropout. @@ -2209,6 +2233,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di return ret +@tf_export("nn.top_k") def top_k(input, k=1, sorted=True, name=None): """Finds values and indices of the `k` largest entries for the last dimension. @@ -2266,6 +2291,7 @@ def nth_element(input, n, reverse=False, name=None): return gen_nn_ops.nth_element(input, n, reverse=reverse, name=name) +@tf_export("nn.conv1d") @deprecation.deprecated_arg_values( None, "`NCHW` for data_format is deprecated, use `NCW` instead", warn_once=True, data_format="NCHW") @@ -2300,7 +2326,7 @@ def conv1d(value, filters, stride, padding, returned to the caller. Args: - value: A 3D `Tensor`. Must be of type `float32` or `float64`. + value: A 3D `Tensor`. Must be of type `float16` or `float32`. filters: A 3D `Tensor`. Must have the same type as `input`. stride: An `integer`. The number of entries by which the filter is moved right at each step. @@ -2451,6 +2477,7 @@ def _calc_dilation2d_flops(graph, node): return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) +@tf_export("nn.erosion2d") def erosion2d(value, kernel, strides, rates, padding, name=None): """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors. @@ -2508,6 +2535,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None): name=name)) +@tf_export("nn.in_top_k") def in_top_k(predictions, targets, k, name=None): r"""Says whether the targets are in the top `K` predictions. diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 66bc0803b736829ad8d8f3243bf23e146f2f89b9..5a45bdc1e5e1d38a34176ed9443fcd1713f38e1e 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -131,8 +131,7 @@ class LogPoissonLossTest(test_lib.TestCase): y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False) y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True) y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False) - y_tf_stirling = nn_impl.log_poisson_loss( - z_np, x_np, compute_full_loss=True) + y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True) y_tf_np = self.evaluate(y_tf) y_tf_np_stirling = self.evaluate(y_tf_stirling) eps = 1e-3 @@ -773,8 +772,8 @@ class ComputeSampledLogitsTest(test_lib.TestCase): def _SoftmaxCrossEntropyWithLogits(logits, targets): # logits, targets: float arrays of the same shape. assert logits.shape == targets.shape - stable_exp_logits = np.exp(logits - np.amax( - logits, axis=1, keepdims=True)) + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) @@ -865,8 +864,8 @@ class LeakyReluTest(test_lib.TestCase): batch_size = 3 height, width = 4, 4 np.random.seed(1) # Make it reproducible. - inputs = np.random.uniform( - size=(batch_size, height, width, 3)).astype(np.float32) + inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype( + np.float32) inputs = constant_op.constant(inputs) outputs = nn_ops.leaky_relu(inputs) @@ -878,11 +877,14 @@ class LeakyReluTest(test_lib.TestCase): self.assertAllClose(inputs, outputs) def testValues(self): - np_values = np.array([-1.0, 0.0, 0.5, 1.0, 2.0], dtype=np.float32) - outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) - with self.test_session() as sess: - outputs = sess.run(outputs) - self.assertAllClose(outputs, [-0.2, 0.0, 0.5, 1.0, 2.0]) + for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: + np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) + outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) + with self.test_session() as sess: + outputs = sess.run(outputs) + tol = 2e-3 if dtype == np.float16 else 1e-6 + self.assertAllClose( + outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) class SwishTest(test_lib.TestCase): @@ -913,7 +915,10 @@ class SwishTest(test_lib.TestCase): class MomentsTest(test_lib.TestCase): - def doOutputTest(self, input_shape, moments_axes, tol=1e-4, + def doOutputTest(self, + input_shape, + moments_axes, + tol=1e-4, check_gradients=False): for mu in [0.0, 1.0, 1e3]: for sigma in [1.0, 0.1]: diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py index f3558fda9ca940f2567a451bb6ad14feb10aaba7..b4ce1cbf25346412e2781a520b7e2cdcf720bcd5 100644 --- a/tensorflow/python/ops/numerics.py +++ b/tensorflow/python/ops/numerics.py @@ -24,8 +24,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("verify_tensor_all_finite") def verify_tensor_all_finite(t, msg, name=None): """Assert that the tensor does not contain any NaN's or Inf's. @@ -45,6 +47,7 @@ def verify_tensor_all_finite(t, msg, name=None): return out +@tf_export("add_check_numerics_ops") def add_check_numerics_ops(): """Connect a `check_numerics` to every floating point tensor. diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index 7b6f08f68cec60a464a31671bab2cf88b3293bb9..b0315ceee268be8ac1813dae5a262a7d9496e154 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.gen_parsing_ops import * # pylint: enable=wildcard-import,undefined-variable from tensorflow.python.platform import tf_logging +from tensorflow.python.util.tf_export import tf_export ops.NotDifferentiable("DecodeRaw") @@ -44,6 +45,7 @@ ops.NotDifferentiable("SerializeTensor") ops.NotDifferentiable("StringToNumber") +@tf_export("VarLenFeature") class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])): """Configuration for parsing a variable-length input feature. @@ -53,6 +55,7 @@ class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])): pass +@tf_export("SparseFeature") class SparseFeature( collections.namedtuple( "SparseFeature", @@ -127,6 +130,7 @@ class SparseFeature( cls, index_key, value_key, dtype, size, already_sorted) +@tf_export("FixedLenFeature") class FixedLenFeature(collections.namedtuple( "FixedLenFeature", ["shape", "dtype", "default_value"])): """Configuration for parsing a fixed-length input feature. @@ -146,6 +150,7 @@ class FixedLenFeature(collections.namedtuple( cls, shape, dtype, default_value) +@tf_export("FixedLenSequenceFeature") class FixedLenSequenceFeature(collections.namedtuple( "FixedLenSequenceFeature", ["shape", "dtype", "allow_missing", "default_value"])): @@ -355,6 +360,7 @@ def _prepend_none_dimension(features): return features +@tf_export("parse_example") def parse_example(serialized, features, name=None, example_names=None): # pylint: disable=line-too-long """Parses `Example` protos into a `dict` of tensors. @@ -715,6 +721,7 @@ def _parse_example_raw(serialized, return dict(zip(sparse_keys + dense_keys, sparse_tensors + dense_values)) +@tf_export("parse_single_example") def parse_single_example(serialized, features, name=None, example_names=None): """Parses a single `Example` proto. @@ -850,6 +857,7 @@ def _parse_single_example_raw(serialized, return outputs +@tf_export("parse_single_sequence_example") def parse_single_sequence_example( serialized, context_features=None, sequence_features=None, example_name=None, name=None): @@ -1171,6 +1179,7 @@ def _parse_single_sequence_example_raw(serialized, # Swap `name` and `na_value` for backward compatibility. +@tf_export("decode_csv") def decode_csv(records, record_defaults, field_delim=",", use_quote_delim=True, name=None, na_value=""): # pylint: disable=protected-access diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py index edcc0e1d7c11f86ace8e42221308270ccc188b5d..174cabdf8027e75c780441d06a98a24c19be0cfc 100644 --- a/tensorflow/python/ops/partitioned_variables.py +++ b/tensorflow/python/ops/partitioned_variables.py @@ -58,6 +58,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export __all__ = [ "create_partitioned_variables", @@ -67,6 +68,7 @@ __all__ = [ ] +@tf_export("variable_axis_size_partitioner") def variable_axis_size_partitioner( max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None): """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`. @@ -151,6 +153,7 @@ def variable_axis_size_partitioner( return _partitioner +@tf_export("min_max_variable_partitioner") def min_max_variable_partitioner(max_partitions=1, axis=0, min_slice_size=256 << 10, bytes_per_string_element=16): @@ -214,6 +217,7 @@ def min_max_variable_partitioner(max_partitions=1, axis=0, return _partitioner +@tf_export("fixed_size_partitioner") def fixed_size_partitioner(num_shards, axis=0): """Partitioner to specify a fixed number of shards along given axis. @@ -232,6 +236,7 @@ def fixed_size_partitioner(num_shards, axis=0): return _partitioner +@tf_export("create_partitioned_variables") def create_partitioned_variables( shape, slicing, initializer, dtype=dtypes.float32, trainable=True, collections=None, name=None, reuse=None): diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index a2264a7bdfff398e405ccd4a509d20c592ee886b..2c86358d21b1c280b8d7ade625fd4b7a44c5de26 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import math_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_random_ops import * +from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import @@ -43,6 +44,7 @@ def _ShapeTensor(shape): # pylint: disable=protected-access +@tf_export("random_normal") def random_normal(shape, mean=0.0, stddev=1.0, @@ -135,6 +137,7 @@ def parameterized_truncated_normal(shape, return rnd +@tf_export("truncated_normal") def truncated_normal(shape, mean=0.0, stddev=1.0, @@ -179,6 +182,7 @@ ops.NotDifferentiable("ParameterizedTruncatedNormal") ops.NotDifferentiable("TruncatedNormal") +@tf_export("random_uniform") def random_uniform(shape, minval=0, maxval=None, @@ -244,6 +248,7 @@ def random_uniform(shape, ops.NotDifferentiable("RandomUniform") +@tf_export("random_shuffle") def random_shuffle(value, seed=None, name=None): """Randomly shuffles a tensor along its first dimension. @@ -274,6 +279,7 @@ def random_shuffle(value, seed=None, name=None): value, seed=seed1, seed2=seed2, name=name) +@tf_export("random_crop") def random_crop(value, size, seed=None, name=None): """Randomly crops a tensor to a given size. @@ -316,6 +322,7 @@ def random_crop(value, size, seed=None, name=None): return array_ops.slice(value, offset, size, name=name) +@tf_export("multinomial") def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None): """Draws samples from a multinomial distribution. @@ -351,6 +358,7 @@ def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None): ops.NotDifferentiable("Multinomial") +@tf_export("random_gamma") def random_gamma(shape, alpha, beta=None, @@ -418,6 +426,7 @@ def random_gamma(shape, ops.NotDifferentiable("RandomGamma") +@tf_export("random_poisson") def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): """Draws `shape` samples from each of the given Poisson distribution(s). diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index fd14740a00a24b006cd1e47b20d46e86e261528a..a1008f1c834f7c01af0ff8b3a0a648f499ce1f8a 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -35,12 +35,12 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access @@ -321,6 +321,7 @@ def _reverse_seq(input_seq, lengths): return results +@tf_export("nn.bidirectional_dynamic_rnn") def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, initial_state_fw=None, initial_state_bw=None, dtype=None, parallel_iterations=None, @@ -450,6 +451,7 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, return (outputs, output_states) +@tf_export("nn.dynamic_rnn") def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None): @@ -723,6 +725,8 @@ def _dynamic_rnn_loop(cell, if sequence_length is not None: min_sequence_length = math_ops.reduce_min(sequence_length) max_sequence_length = math_ops.reduce_max(sequence_length) + else: + max_sequence_length = time_steps time = array_ops.constant(0, dtype=dtypes.int32, name="time") @@ -807,28 +811,18 @@ def _dynamic_rnn_loop(cell, return (time + 1, output_ta_t, new_state) - # TODO(pbar) `loop_bound` can be reduced to `max_sequence_length` once - # TensorArray shape inference is working. When sequence lengths are highly - # variable, this will reduce the performance overheads of padding to a fixed - # maximum length. - loop_bound = time_steps - - # This is a workaround since we cannot currently use maximum_iterations if - # time_steps is defined inside control flow, see the comment in - # control_flow_ops.py. - if (context.in_eager_mode() or - not (control_flow_util.IsInWhileLoop(time_steps.op) or - control_flow_util.IsInCond(time_steps.op))): - maximum_iterations = time_steps + if in_graph_mode: + loop_bound = max_sequence_length else: - maximum_iterations = None + # Using max_sequence_length isn't currently supported in the Eager branch. + loop_bound = time_steps _, output_final_ta, final_state = control_flow_ops.while_loop( cond=lambda time, *_: time < loop_bound, body=_time_step, loop_vars=(time, output_ta, state), parallel_iterations=parallel_iterations, - maximum_iterations=maximum_iterations, + maximum_iterations=time_steps, swap_memory=swap_memory) # Unpack final output if not using output tuples. @@ -850,6 +844,7 @@ def _dynamic_rnn_loop(cell, return (final_outputs, final_state) +@tf_export("nn.raw_rnn") def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None): """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`. @@ -1157,6 +1152,7 @@ def raw_rnn(cell, loop_fn, return (emit_ta, final_state, final_loop_state) +@tf_export("nn.static_rnn") def static_rnn(cell, inputs, initial_state=None, @@ -1326,6 +1322,7 @@ def static_rnn(cell, return (outputs, state) +@tf_export("nn.static_state_saving_rnn") def static_state_saving_rnn(cell, inputs, state_saver, @@ -1410,6 +1407,7 @@ def static_state_saving_rnn(cell, return (outputs, state) +@tf_export("nn.static_bidirectional_rnn") def static_bidirectional_rnn(cell_fw, cell_bw, inputs, diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index b41aff76d4961c8a563599ee01e5956ab05fc71d..f1ac3e9bafa09e4647b4a4263e74fad29b643fd5 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export _BIAS_VARIABLE_NAME = "bias" @@ -133,6 +134,7 @@ def _zero_state_tensors(state_size, batch_size, dtype): return nest.map_structure(get_state_shape, state_size) +@tf_export("nn.rnn_cell.RNNCell") class RNNCell(base_layer.Layer): """Abstract object representing an RNN cell. @@ -294,6 +296,7 @@ class _LayerRNNCell(RNNCell): *args, **kwargs) +@tf_export("nn.rnn_cell.BasicRNNCell") class BasicRNNCell(_LayerRNNCell): """The most basic RNN cell. @@ -351,6 +354,7 @@ class BasicRNNCell(_LayerRNNCell): return output, output +@tf_export("nn.rnn_cell.GRUCell") class GRUCell(_LayerRNNCell): """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). @@ -448,6 +452,7 @@ class GRUCell(_LayerRNNCell): _LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) +@tf_export("nn.rnn_cell.LSTMStateTuple") class LSTMStateTuple(_LSTMStateTuple): """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. @@ -467,6 +472,7 @@ class LSTMStateTuple(_LSTMStateTuple): return c.dtype +@tf_export("nn.rnn_cell.BasicLSTMCell") class BasicLSTMCell(_LayerRNNCell): """Basic LSTM recurrent network cell. @@ -591,6 +597,7 @@ class BasicLSTMCell(_LayerRNNCell): return new_h, new_state +@tf_export("nn.rnn_cell.LSTMCell") class LSTMCell(_LayerRNNCell): """Long short-term memory unit (LSTM) recurrent network cell. @@ -834,6 +841,7 @@ def _default_dropout_state_filter_visitor(substate): return True +@tf_export("nn.rnn_cell.DropoutWrapper") class DropoutWrapper(RNNCell): """Operator adding dropout to inputs and outputs of the given cell.""" @@ -979,6 +987,10 @@ class DropoutWrapper(RNNCell): string = (str(self._seed) + salt).encode("utf-8") return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF + @property + def wrapped_cell(self): + return self._cell + @property def state_size(self): return self._cell.state_size @@ -1058,6 +1070,7 @@ class DropoutWrapper(RNNCell): return output, new_state +@tf_export("nn.rnn_cell.ResidualWrapper") class ResidualWrapper(RNNCell): """RNNCell wrapper that ensures cell inputs are added to the outputs.""" @@ -1113,6 +1126,7 @@ class ResidualWrapper(RNNCell): return (res_outputs, new_state) +@tf_export("nn.rnn_cell.DeviceWrapper") class DeviceWrapper(RNNCell): """Operator that ensures an RNNCell runs on a particular device.""" @@ -1147,6 +1161,7 @@ class DeviceWrapper(RNNCell): return self._cell(inputs, state, scope=scope) +@tf_export("nn.rnn_cell.MultiRNNCell") class MultiRNNCell(RNNCell): """RNN cell composed sequentially of multiple simple cells.""" diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index c0c1ade495455df6a4965eefba4b823ca84e7c31..4b5072fd6799ae289d3c1a1b2a40878e36604bf4 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -33,6 +33,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_script_ops +from tensorflow.python.util.tf_export import tf_export class EagerFunc(object): @@ -243,6 +244,7 @@ def eager_py_func(func, inp, Tout, name=None): return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name) +@tf_export("py_func") def py_func(func, inp, Tout, stateful=True, name=None): """Wraps a python function and uses it as a TensorFlow op. diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index dc4d913c938a89f23297c02c2d18b286fd3bb9e8..cedd36c1deed541adcf601ff9447345e2279e8f9 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export def encode_resource_handle(resource_handle): @@ -141,6 +142,7 @@ class TensorHandle(object): return feeder.op.name + ";" + TensorHandle._get_reader_key(handle) +@tf_export("get_session_handle") def get_session_handle(data, name=None): """Return the handle of `data`. @@ -183,6 +185,7 @@ def get_session_handle(data, name=None): return gen_data_flow_ops._get_session_handle(data, name=name) # pylint: disable=protected-access +@tf_export("get_session_tensor") def get_session_tensor(handle, dtype, name=None): """Get the tensor of type `dtype` by feeding a tensor handle. @@ -223,6 +226,7 @@ def get_session_tensor(handle, dtype, name=None): return (holder, tensor) +@tf_export("delete_session_tensor") def delete_session_tensor(handle, name=None): """Delete the tensor for the given tensor handle. diff --git a/tensorflow/python/ops/sets_impl.py b/tensorflow/python/ops/sets_impl.py index 6aa9e3419ea497594b455bc5481dec5a77404bcf..b0eecd8a1e812857de8f47e1370e4fc5f1004bc0 100644 --- a/tensorflow/python/ops/sets_impl.py +++ b/tensorflow/python/ops/sets_impl.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_set_ops +from tensorflow.python.util.tf_export import tf_export _VALID_DTYPES = set([ @@ -30,6 +31,7 @@ _VALID_DTYPES = set([ dtypes.uint8, dtypes.uint16, dtypes.string]) +@tf_export("sets.set_size") def set_size(a, validate_indices=True): """Compute number of unique elements along last dimension of `a`. @@ -131,6 +133,7 @@ def _set_operation(a, b, set_operation, validate_indices=True): return sparse_tensor.SparseTensor(indices, values, shape) +@tf_export("sets.set_intersection") def set_intersection(a, b, validate_indices=True): """Compute set intersection of elements in last dimension of `a` and `b`. @@ -197,6 +200,7 @@ def set_intersection(a, b, validate_indices=True): return _set_operation(a, b, "intersection", validate_indices) +@tf_export("sets.set_difference") def set_difference(a, b, aminusb=True, validate_indices=True): """Compute set difference of elements in last dimension of `a` and `b`. @@ -267,6 +271,7 @@ def set_difference(a, b, aminusb=True, validate_indices=True): return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices) +@tf_export("sets.set_union") def set_union(a, b, validate_indices=True): """Compute set union of elements in last dimension of `a` and `b`. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index c368d166f5654a3fc5c3464e552e6497b6ee19a3..3224856d7be0674a2cc064a226bf1a38abb6bc2b 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -65,6 +65,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.gen_sparse_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export def _convert_to_sparse_tensor(sp_input): @@ -108,6 +109,7 @@ def _convert_to_sparse_tensors(sp_inputs): # pylint: disable=protected-access +@tf_export("sparse_concat") def sparse_concat(axis, sp_inputs, name=None, @@ -236,6 +238,7 @@ def sparse_concat(axis, return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) +@tf_export("sparse_add") def sparse_add(a, b, thresh=0): """Adds two tensors, at least one of each is a `SparseTensor`. @@ -463,6 +466,7 @@ def sparse_dense_cwise_add(sp_t, dense_t): return sparse_tensor.SparseTensor(sp_t.indices, result, sp_t.dense_shape) +@tf_export("sparse_reorder") def sparse_reorder(sp_input, name=None): """Reorders a `SparseTensor` into the canonical, row-major ordering. @@ -511,6 +515,7 @@ def sparse_reorder(sp_input, name=None): return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape) +@tf_export("sparse_reshape") def sparse_reshape(sp_input, shape, name=None): """Reshapes a `SparseTensor` to represent values in a new dense shape. @@ -603,6 +608,7 @@ class KeywordRequired(object): return "KeywordRequired()" +@tf_export("sparse_split") def sparse_split(keyword_required=KeywordRequired(), sp_input=None, num_split=None, axis=None, name=None, split_dim=None): @@ -669,6 +675,7 @@ def sparse_split(keyword_required=KeywordRequired(), return sparse_tensors +@tf_export("sparse_slice") def sparse_slice(sp_input, start, size, name=None): """Slice a `SparseTensor` based on the `start` and `size. @@ -713,6 +720,8 @@ def sparse_slice(sp_input, start, size, name=None): output_values, output_shape) + +@tf_export("sparse_to_dense") def sparse_to_dense(sparse_indices, output_shape, sparse_values, @@ -768,6 +777,7 @@ def sparse_to_dense(sparse_indices, name=name) +@tf_export("sparse_reduce_max") def sparse_reduce_max(sp_input, axis=None, keep_dims=False, reduction_axes=None): """Computes the max of elements across dimensions of a SparseTensor. @@ -815,6 +825,7 @@ def sparse_reduce_max(sp_input, axis=None, keep_dims=False, keep_dims) +@tf_export("sparse_reduce_max_sparse") def sparse_reduce_max_sparse(sp_input, axis=None, keep_dims=False, reduction_axes=None): """Computes the max of elements across dimensions of a SparseTensor. @@ -852,6 +863,7 @@ def sparse_reduce_max_sparse(sp_input, axis=None, keep_dims=False, return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) +@tf_export("sparse_reduce_sum") def sparse_reduce_sum(sp_input, axis=None, keep_dims=False, reduction_axes=None): """Computes the sum of elements across dimensions of a SparseTensor. @@ -899,6 +911,7 @@ def sparse_reduce_sum(sp_input, axis=None, keep_dims=False, keep_dims) +@tf_export("sparse_reduce_sum_sparse") def sparse_reduce_sum_sparse(sp_input, axis=None, keep_dims=False, reduction_axes=None): """Computes the sum of elements across dimensions of a SparseTensor. @@ -936,6 +949,7 @@ def sparse_reduce_sum_sparse(sp_input, axis=None, keep_dims=False, return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) +@tf_export("sparse_tensor_to_dense") def sparse_tensor_to_dense(sp_input, default_value=0, validate_indices=True, @@ -987,6 +1001,7 @@ def sparse_tensor_to_dense(sp_input, name=name) +@tf_export("sparse_to_indicator") def sparse_to_indicator(sp_input, vocab_size, name=None): """Converts a `SparseTensor` of ids into a dense bool indicator tensor. @@ -1049,6 +1064,7 @@ def sparse_to_indicator(sp_input, vocab_size, name=None): sp_new, default_value=False, validate_indices=False, name=name) +@tf_export("sparse_merge") def sparse_merge(sp_ids, sp_values, vocab_size, name=None, already_sorted=False): """Combines a batch of feature ids and values into a single `SparseTensor`. @@ -1189,6 +1205,7 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None, return result if already_sorted else sparse_reorder(result) +@tf_export("sparse_retain") def sparse_retain(sp_input, to_retain): """Retains specified non-empty values within a `SparseTensor`. @@ -1232,6 +1249,7 @@ def sparse_retain(sp_input, to_retain): array_ops.identity(sp_input.dense_shape)) +@tf_export("sparse_reset_shape") def sparse_reset_shape(sp_input, new_shape=None): """Resets the shape of a `SparseTensor` with indices and values unchanged. @@ -1333,6 +1351,7 @@ def sparse_reset_shape(sp_input, new_shape=None): return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor) +@tf_export("sparse_fill_empty_rows") def sparse_fill_empty_rows(sp_input, default_value, name=None): """Fills empty rows in the input 2-D `SparseTensor` with a default value. @@ -1396,6 +1415,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): empty_row_indicator) +@tf_export("serialize_sparse") def serialize_sparse(sp_input, name=None, out_type=dtypes.string): """Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object. @@ -1421,6 +1441,7 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string): out_type=out_type) +@tf_export("serialize_many_sparse") def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string): """Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`. @@ -1521,6 +1542,7 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None): return sparse_tensor.SparseTensor(output_indices, output_values, output_shape) +@tf_export("deserialize_many_sparse") def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): """Deserialize and concatenate `SparseTensors` from a serialized minibatch. @@ -1590,6 +1612,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): return sparse_tensor.SparseTensor(output_indices, output_values, output_shape) +@tf_export("sparse_tensor_dense_matmul") def sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, @@ -1806,6 +1829,7 @@ def sparse_tensor_dense_matmul(sp_a, adjoint_b=adjoint_b) +@tf_export("sparse_softmax") def sparse_softmax(sp_input, name=None): """Applies softmax to a batched N-D `SparseTensor`. @@ -1860,6 +1884,7 @@ def sparse_softmax(sp_input, name=None): sp_input.indices, out_vals, sp_input.dense_shape) +@tf_export("sparse_maximum") def sparse_maximum(sp_a, sp_b, name=None): """Returns the element-wise max of two SparseTensors. @@ -1896,6 +1921,7 @@ def sparse_maximum(sp_a, sp_b, name=None): return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape) +@tf_export("sparse_minimum") def sparse_minimum(sp_a, sp_b, name=None): """Returns the element-wise min of two SparseTensors. @@ -1932,6 +1958,7 @@ def sparse_minimum(sp_a, sp_b, name=None): return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape) +@tf_export("sparse_transpose") def sparse_transpose(sp_input, perm=None, name=None): """Transposes a `SparseTensor` diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index fe3f7343222f7b10bc6af272146e8960d6f39c3d..19900870725f5f01c4ba12979265a5533297d4c3 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -31,9 +31,11 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export # TODO(b/27419586) Change docstring for required dtype of x once int allowed +@tf_export('lbeta') def lbeta(x, name='lbeta'): r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. @@ -82,6 +84,7 @@ def lbeta(x, name='lbeta'): return result +@tf_export('einsum', 'linalg.einsum') def einsum(equation, *inputs, **kwargs): """A generalized contraction between tensors of arbitrary dimension. diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py index 69f868c67ada748ef76029155e470d79a643cbf4..a5796882768a87c76e0acdec9b3d99caf41e02eb 100644 --- a/tensorflow/python/ops/spectral_ops.py +++ b/tensorflow/python/ops/spectral_ops.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import array_ops as _array_ops from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import math_ops as _math_ops from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export def _infer_fft_length_for_rfft(input_tensor, fft_rank): @@ -164,11 +165,17 @@ ifft2d = gen_spectral_ops.ifft2d fft3d = gen_spectral_ops.fft3d ifft3d = gen_spectral_ops.ifft3d rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft") +tf_export("spectral.rfft")(rfft) irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft") +tf_export("spectral.irfft")(irfft) rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d") +tf_export("spectral.rfft2d")(rfft2d) irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d") +tf_export("spectral.irfft2d")(irfft2d) rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d") +tf_export("spectral.rfft3d")(rfft3d) irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d") +tf_export("spectral.irfft3d")(irfft3d) def _validate_dct_arguments(dct_type, n, axis, norm): @@ -184,6 +191,7 @@ def _validate_dct_arguments(dct_type, n, axis, norm): # TODO(rjryan): Implement `type`, `n` and `axis` parameters. +@tf_export("spectral.dct") def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index dee495f78fa5c2fa099772d0a84f5ff0981c8c59..3cc76fdbf34ff6de47d98400cd826d671c9178eb 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -89,6 +89,7 @@ from tensorflow.python.ops import gen_state_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_state_ops import * +from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import @@ -189,6 +190,7 @@ def is_variable_initialized(ref, name=None): name=name) +@tf_export("assign_sub") def assign_sub(ref, value, use_locking=None, name=None): """Update 'ref' by subtracting 'value' from it. @@ -217,6 +219,7 @@ def assign_sub(ref, value, use_locking=None, name=None): return ref.assign_sub(value) +@tf_export("assign_add") def assign_add(ref, value, use_locking=None, name=None): """Update 'ref' by adding 'value' to it. @@ -245,6 +248,7 @@ def assign_add(ref, value, use_locking=None, name=None): return ref.assign_add(value) +@tf_export("assign") def assign(ref, value, validate_shape=None, use_locking=None, name=None): """Update 'ref' by assigning 'value' to it. @@ -277,6 +281,7 @@ def assign(ref, value, validate_shape=None, use_locking=None, name=None): return ref.assign(value) +@tf_export("count_up_to") def count_up_to(ref, limit, name=None): r"""Increments 'ref' until it reaches 'limit'. @@ -299,6 +304,7 @@ def count_up_to(ref, limit, name=None): ref.handle, limit, T=ref.dtype, name=name) +@tf_export("scatter_update") def scatter_update(ref, indices, updates, use_locking=True, name=None): # pylint: disable=line-too-long r"""Applies sparse updates to a variable reference. @@ -354,6 +360,7 @@ def scatter_update(ref, indices, updates, use_locking=True, name=None): return ref.read_value() +@tf_export("scatter_nd_update") def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): r"""Applies sparse `updates` to individual values or slices in a Variable. diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index f30e79a108f159bb03237f8c232d1ee467ff458d..b8c39d91b41790c6441594b175e8eaa03620e1ec 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -47,9 +47,11 @@ from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_string_ops import * from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import +@tf_export("string_split") def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name """Split elements of `source` based on `delimiter` into a `SparseTensor`. @@ -120,6 +122,7 @@ def _reduce_join_reduction_dims(x, axis, reduction_indices): return math_ops.range(array_ops.rank(x) - 1, -1, -1) +@tf_export("reduce_join") def reduce_join(inputs, axis=None, keep_dims=False, separator="", diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py index 2cf2eda16e69bcfab766c7adaa4b5d8b40d99723..7f4f4ce5ab4ee2bd309932cb81f05775996371d6 100644 --- a/tensorflow/python/ops/summary_ops.py +++ b/tensorflow/python/ops/summary_ops.py @@ -25,9 +25,11 @@ from tensorflow.python.ops import summary_op_util # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_logging_ops import * +from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import +@tf_export("summary.tensor_summary") def tensor_summary(name, tensor, summary_description=None, diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 99a71cbe79cfb2772a279960d2aec1def52960c0..84449e00beb4d2901f57c7cd41a4e755fe343c8c 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -29,11 +29,13 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.tf_export import tf_export __all__ = ["make_template"] +@tf_export("make_template") def make_template(name_, func_, create_scope_now_=False, unique_name_=None, custom_getter_=None, **kwargs): """Given an arbitrary function, wrap it so that it does variable sharing. diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 398521c9b5ae9240f03a2ba5c4b0681bd8b3bfd7..5cdf03509e3c427deec7e26345059211001e2131 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import tf_should_use +from tensorflow.python.util.tf_export import tf_export # _GraphTensorArray accesses many of the hidden generated ops, but is in @@ -711,6 +712,7 @@ class _EagerTensorArray(object): # TensorArray is designed to hide an underlying implementation object # and as such accesses many of that object's hidden fields. # pylint: disable=protected-access +@tf_export("TensorArray") class TensorArray(object): """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays. diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 3a39af8e207f154446204b452a00537f9c25fdb1..db594ac6a0bd3c5380ec4dc368a091dbc48980eb 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -27,6 +27,7 @@ import sys import traceback import six +from six import iteritems from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.eager import context @@ -40,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.tf_export import tf_export __all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable", "get_local_variable", "variable_scope", @@ -186,6 +188,7 @@ class _ReuseMode(enum.Enum): # REUSE_TRUE = 3 AUTO_REUSE = _ReuseMode.AUTO_REUSE +tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE") AUTO_REUSE.__doc__ = """ When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that get_variable() should create the requested variable if it doesn't exist or, if @@ -785,26 +788,16 @@ class _VariableStore(object): if use_resource is None: # Set the default value if unspecified. use_resource = False - if use_resource: - v = resource_variable_ops.ResourceVariable( - initial_value=init_val, - name=name, - trainable=trainable, - collections=collections, - caching_device=caching_device, - dtype=variable_dtype, - validate_shape=validate_shape, - constraint=constraint) - else: - v = variables.Variable( - initial_value=init_val, - name=name, - trainable=trainable, - collections=collections, - caching_device=caching_device, - dtype=variable_dtype, - validate_shape=validate_shape, - constraint=constraint) + v = variable( + initial_value=init_val, + name=name, + trainable=trainable, + collections=collections, + caching_device=caching_device, + dtype=variable_dtype, + validate_shape=validate_shape, + constraint=constraint, + use_resource=use_resource) if context.in_graph_mode() or self._store_eager_variables: # In eager mode we do not want to keep default references to Variable # objects as this will prevent their memory from being released. @@ -863,12 +856,14 @@ class _VariableStore(object): # To stop regularization, use this regularizer +@tf_export("no_regularizer") def no_regularizer(_): """Use this function to prevent regularization of variables.""" return None # TODO(alive): support caching devices and partitioned variables in Eager mode. +@tf_export("VariableScope") class VariableScope(object): """Variable scope object to carry defaults to provide to `get_variable`. @@ -1168,6 +1163,7 @@ _VARSTORE_KEY = ("__variable_store",) _VARSCOPE_KEY = ("__varscope",) +@tf_export("get_variable_scope") def get_variable_scope(): """Returns the current variable scope.""" scope = ops.get_collection(_VARSCOPE_KEY) @@ -1247,7 +1243,38 @@ class EagerVariableStore(object): key=lambda x: x.name) # pylint: enable=protected-access + def copy(self): + """Copy this variable store and all of its contents. + + Variables contained in this store will be copied over to the new variable + store, meaning that they can be modified without affecting the variables in + this store. + + Returns: + A new EagerVariableStore instance containing copied variables. + """ + # pylint: disable=protected-access + new_store = EagerVariableStore() + for key, var in iteritems(self._store._vars): + # Strip device out of variable name. + try: + index = var.name.index(":") + except ValueError: + stripped_var_name = var.name + else: + stripped_var_name = var.name[:index] + + # Create new variable with same value, name, and "trainable" flag. + new_var = resource_variable_ops.ResourceVariable( + var.read_value(), + name=stripped_var_name, + trainable=var._trainable) + new_store._store._vars[key] = new_var + return new_store + # pylint: enable=protected-access + +@tf_export("get_variable") def get_variable(name, shape=None, dtype=None, @@ -1359,6 +1386,7 @@ get_variable.__doc__ = get_variable_or_local_docstring % ( @functools.wraps(get_variable) +@tf_export("get_local_variable") def get_local_variable(*args, **kwargs): kwargs["trainable"] = False if "collections" in kwargs: @@ -1673,7 +1701,8 @@ def _get_unique_variable_scope(prefix): # Named like a function for backwards compatibility with the # @tf_contextlib.contextmanager version, which was switched to a class to avoid # some object creation overhead. -class variable_scope(object): # pylint: disable=invalid-name +@tf_export("variable_scope") # pylint: disable=invalid-name +class variable_scope(object): """A context manager for defining ops that creates variables (layers). This context manager validates that the (optional) `values` are from the same @@ -2006,6 +2035,7 @@ class variable_scope(object): # pylint: disable=invalid-name # pylint: disable=g-doc-return-or-yield +@tf_export("variable_op_scope") @tf_contextlib.contextmanager def variable_op_scope(values, name_or_scope, @@ -2067,21 +2097,26 @@ def _compute_slice_dim_and_shape(full_shape, slicing): return slice_dim, slice_shape -def variable(initial_value=None, - trainable=True, - collections=None, - validate_shape=True, - caching_device=None, - name=None, - dtype=None, - use_resource=None): +def default_variable_creator(next_creator=None, **kwargs): + """Default variable creator.""" + assert next_creator is None + initial_value = kwargs.get("initial_value", None) + trainable = kwargs.get("trainable", True) + collections = kwargs.get("collections", None) + validate_shape = kwargs.get("validate_shape", True) + caching_device = kwargs.get("caching_device", None) + name = kwargs.get("name", None) + dtype = kwargs.get("dtype", None) + constraint = kwargs.get("constraint", None) + use_resource = kwargs.get("use_resource", None) if use_resource is None: use_resource = get_variable_scope().use_resource if use_resource or (use_resource is None and context.in_eager_mode()): return resource_variable_ops.ResourceVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, - caching_device=caching_device, name=name, dtype=dtype) + caching_device=caching_device, name=name, dtype=dtype, + constraint=constraint) elif not use_resource and context.in_eager_mode(): raise RuntimeError( "VariableScope should use resource variable when eager execution is" @@ -2091,4 +2126,95 @@ def variable(initial_value=None, return variables.Variable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, - caching_device=caching_device, name=name, dtype=dtype) + caching_device=caching_device, name=name, dtype=dtype, + constraint=constraint) + + +def _make_getter(captured_getter, captured_previous): + """Gets around capturing loop variables in python being broken.""" + return lambda **kwargs: captured_getter(captured_previous, **kwargs) + + +def variable(initial_value=None, + trainable=True, + collections=None, + validate_shape=True, + caching_device=None, + name=None, + dtype=None, + constraint=None, + use_resource=None): + previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) + for getter in ops.get_default_graph()._get_variable_creator_stack(): # pylint: disable=protected-access + previous_getter = _make_getter(getter, previous_getter) + return previous_getter(initial_value=initial_value, + trainable=trainable, + collections=collections, + validate_shape=validate_shape, + caching_device=caching_device, + name=name, dtype=dtype, + constraint=constraint, + use_resource=use_resource) + + +@tf_contextlib.contextmanager +def variable_creator_scope(variable_creator): + """Scope which defines a variable creation function to be used by variable(). + + variable_creator is expected to be a function with the following signature: + + ``` + def variable_creator(next_creator, **kwargs) + ``` + + The creator is supposed to eventually call the next_creator to create a + variable if it does want to create a variable and not call Variable or + ResourceVariable directly. This helps make creators composable. A creator may + choose to create multiple variables, return already existing variables, or + simply register that a variable was created and defer to the next creators in + line. Creators can also modify the keyword arguments seen by the next + creators. + + Custom getters in the variable scope will eventually resolve down to these + custom creators when they do create variables. + + The valid keyword arguments in kwds are: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) + trainable: If `True`, the default, also adds the variable to the graph + collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as + the default list of variables to use by the `Optimizer` classes. + collections: List of graph collections keys. The new variable is added to + these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + validate_shape: If `False`, allows the variable to be initialized with a + value of unknown shape. If `True`, the default, the shape of + `initial_value` must be known. + caching_device: Optional device string describing where the Variable + should be cached for reading. Defaults to the Variable's device. + If not `None`, caches on another device. Typical use is to cache + on the device where the Ops using the Variable reside, to deduplicate + copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + dtype: If set, initial_value will be converted to the given type. + If `None`, either the datatype will be kept (if `initial_value` is + a Tensor), or `convert_to_tensor` will decide. + constraint: A constraint function to be applied to the variable after + updates by some algorithms. + use_resource: if True, a ResourceVariable is always created. + + This set may grow over time, so it's important the signature of creators is as + mentioned above. + + Args: + variable_creator: the passed creator + + Yields: + A scope in which the creator is active + """ + with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access + yield diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index b25855633ed4ce485090fb47b09e1b5ce0ff2228..7d7fa646c08523c5f572f8f4593c1d8fe8615c67 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -31,8 +31,10 @@ from tensorflow.python.ops import state_ops from tensorflow.python.util import compat from tensorflow.python.util import tf_should_use from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.tf_export import tf_export +@tf_export("Variable") class Variable(object): """See the @{$variables$Variables How To} for a high level overview. @@ -1308,6 +1310,7 @@ class PartitionedVariable(object): "assign() has not been implemented for PartitionedVariable.") +@tf_export("global_variables") def global_variables(scope=None): """Returns global variables. @@ -1333,6 +1336,7 @@ def global_variables(scope=None): return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +@tf_export("all_variables") @deprecated("2017-03-02", "Please use tf.global_variables instead.") def all_variables(): """See `tf.global_variables`.""" @@ -1357,6 +1361,7 @@ def _all_saveable_objects(scope=None): ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope)) +@tf_export("local_variables") def local_variables(scope=None): """Returns local variables. @@ -1384,6 +1389,7 @@ def local_variables(scope=None): return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope) +@tf_export("model_variables") def model_variables(scope=None): """Returns all variables in the MODEL_VARIABLES collection. @@ -1400,6 +1406,7 @@ def model_variables(scope=None): return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope) +@tf_export("trainable_variables") def trainable_variables(scope=None): """Returns all variables created with `trainable=True`. @@ -1421,6 +1428,7 @@ def trainable_variables(scope=None): return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope) +@tf_export("moving_average_variables") def moving_average_variables(scope=None): """Returns all variables that maintain their moving averages. @@ -1442,6 +1450,7 @@ def moving_average_variables(scope=None): return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope) +@tf_export("initializers.variables", "variables_initializer") def variables_initializer(var_list, name="init"): """Returns an Op that initializes a list of variables. @@ -1467,6 +1476,7 @@ def variables_initializer(var_list, name="init"): return control_flow_ops.no_op(name=name) +@tf_export("initialize_variables") @tf_should_use.should_use_result @deprecated("2017-03-02", "Use `tf.variables_initializer` instead.") def initialize_variables(var_list, name="init"): @@ -1474,6 +1484,7 @@ def initialize_variables(var_list, name="init"): return variables_initializer(var_list, name=name) +@tf_export("initializers.global_variables", "global_variables_initializer") def global_variables_initializer(): """Returns an Op that initializes global variables. @@ -1487,6 +1498,7 @@ def global_variables_initializer(): return variables_initializer(global_variables()) +@tf_export("initialize_all_variables") @tf_should_use.should_use_result @deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.") def initialize_all_variables(): @@ -1494,6 +1506,7 @@ def initialize_all_variables(): return global_variables_initializer() +@tf_export("initializers.local_variables", "local_variables_initializer") def local_variables_initializer(): """Returns an Op that initializes all local variables. @@ -1507,6 +1520,7 @@ def local_variables_initializer(): return variables_initializer(local_variables()) +@tf_export("initialize_local_variables") @tf_should_use.should_use_result @deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.") def initialize_local_variables(): @@ -1514,6 +1528,7 @@ def initialize_local_variables(): return local_variables_initializer() +@tf_export("is_variable_initialized") @tf_should_use.should_use_result def is_variable_initialized(variable): """Tests if a variable has been initialized. @@ -1528,6 +1543,7 @@ def is_variable_initialized(variable): return state_ops.is_variable_initialized(variable) +@tf_export("assert_variables_initialized") @tf_should_use.should_use_result def assert_variables_initialized(var_list=None): """Returns an Op to check if variables are initialized. @@ -1570,6 +1586,7 @@ def assert_variables_initialized(var_list=None): return array_ops.stack(ranks) +@tf_export("report_uninitialized_variables") @tf_should_use.should_use_result def report_uninitialized_variables(var_list=None, name="report_uninitialized_variables"): diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py index 837bca1dbd06c9ee4adbf05bfc7cf3586d072d16..12dae94a6404e58d31cf88af83251e4bc9e50df3 100644 --- a/tensorflow/python/platform/benchmark.py +++ b/tensorflow/python/platform/benchmark.py @@ -33,6 +33,7 @@ from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export # When a subclass of the Benchmark class is created, it is added to @@ -181,6 +182,7 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)): throughput=throughput, extras=extras) +@tf_export("test.Benchmark") class TensorFlowBenchmark(Benchmark): """Abstract class that provides helpers for TensorFlow benchmarks.""" diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index 202475efdf29e746fb8e985677d1f826741939fb..315889e9aa8851138bf8b07b9803cc2d360f354a 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -34,8 +34,10 @@ from tensorflow.python.lib.io.file_io import stat as Stat from tensorflow.python.lib.io.file_io import walk as Walk # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export +@tf_export('gfile.GFile', 'gfile.Open') class GFile(_FileIO): """File I/O wrappers without thread locking.""" @@ -43,6 +45,7 @@ class GFile(_FileIO): super(GFile, self).__init__(name=name, mode=mode) +@tf_export('gfile.FastGFile') class FastGFile(_FileIO): """File I/O wrappers without thread locking.""" diff --git a/tensorflow/python/platform/sysconfig.py b/tensorflow/python/platform/sysconfig.py index f6c4f2227fbba75e4fdb41ddeaa55ba3f9168677..5c50fa023dc3b216838390d9356a39e70e2362d2 100644 --- a/tensorflow/python/platform/sysconfig.py +++ b/tensorflow/python/platform/sysconfig.py @@ -29,9 +29,11 @@ import os.path as _os_path from tensorflow.python.framework.versions import CXX11_ABI_FLAG as _CXX11_ABI_FLAG from tensorflow.python.framework.versions import MONOLITHIC_BUILD as _MONOLITHIC_BUILD from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # pylint: disable=g-import-not-at-top +@tf_export('sysconfig.get_include') def get_include(): """Get the directory containing the TensorFlow C++ header files. @@ -46,6 +48,7 @@ def get_include(): return _os_path.join(_os_path.dirname(tf.__file__), 'include') +@tf_export('sysconfig.get_lib') def get_lib(): """Get the directory containing the TensorFlow framework library. @@ -56,6 +59,7 @@ def get_lib(): return _os_path.join(_os_path.dirname(tf.__file__)) +@tf_export('sysconfig.get_compile_flags') def get_compile_flags(): """Get the compilation flags for custom operators. @@ -69,6 +73,7 @@ def get_compile_flags(): return flags +@tf_export('sysconfig.get_link_flags') def get_link_flags(): """Get the link flags for custom operators. diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index ec280c6e1ee75f8192b318c6830c62cd9dec9c55..9b7655722ac5a917f2753617f8e99bf2bd2f8d11 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -56,6 +56,7 @@ from tensorflow.python.ops.gradient_checker import compute_gradient # pylint: enable=unused-import,g-bad-import-order import sys +from tensorflow.python.util.tf_export import tf_export if sys.version_info.major == 2: import mock # pylint: disable=g-import-not-at-top,unused-import else: @@ -68,12 +69,14 @@ Benchmark = _googletest.Benchmark # pylint: disable=invalid-name StubOutForTesting = _googletest.StubOutForTesting # pylint: disable=invalid-name +@tf_export('test.main') def main(argv=None): """Runs all unit tests.""" _test_util.InstallStackTraceHandler() return _googletest.main(argv) +@tf_export('test.get_temp_dir') def get_temp_dir(): """Returns a temporary directory for use during tests. @@ -85,6 +88,7 @@ def get_temp_dir(): return _googletest.GetTempDir() +@tf_export('test.test_src_dir_path') def test_src_dir_path(relative_path): """Creates an absolute test srcdir path given a relative path. @@ -98,6 +102,7 @@ def test_src_dir_path(relative_path): return _googletest.test_src_dir_path(relative_path) +@tf_export('test.is_built_with_cuda') def is_built_with_cuda(): """Returns whether TensorFlow was built with CUDA (GPU) support.""" return _test_util.IsGoogleCudaEnabled() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 083931aa8369f46b4e859b5ed4764c4bdfa9c3c3..3f25311a8361d11fbc583413708e148648d95906 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -24,10 +24,13 @@ limitations under the License. %rename("%s") TFE_ContextDisableRunMetadata; %rename("%s") TFE_ContextExportRunMetadata; %rename("%s") TFE_ContextClearCaches; +%rename("%s") TFE_ContextGetDevicePlacementPolicy; +%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; +%rename("%s") TFE_Py_FastPathExecute; %rename("%s") TFE_Py_UID; %rename("%s") TFE_Py_TapeSetNew; %rename("%s") TFE_Py_TapeSetRemove; @@ -118,6 +121,7 @@ limitations under the License. %rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT; %rename("%s") TFE_DEVICE_PLACEMENT_WARN; %rename("%s") TFE_DEVICE_PLACEMENT_SILENT; +%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32; %include "tensorflow/c/eager/c_api.h" @@ -155,7 +159,7 @@ limitations under the License. } $1 = &temp; $1->resize(PyInt_AsLong($input), nullptr); -} +} // Create new Status object. %typemap(in, numinputs=0) TF_Status *out_status { @@ -180,10 +184,14 @@ limitations under the License. } } +// SWIG usually unwraps the tuple that the native Python/C interface generates. +// Since we wanted to have a function with a variable length of arguments, we +// used the native Python/C interface directly (which by default supports +// passing all arguments as a tuple). +%native(TFE_Py_FastPathExecute) TFE_Py_FastPathExecute_C; %include "tensorflow/python/eager/pywrap_tfe.h" - // Clear all typemaps. %typemap(out) TF_DataType; %typemap(out) int64_t; diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 355593eca5dd2f84419035958bfe8eea83e485b8..92c1fcadd29c7858da1d31375c209bf1b21f3103 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -286,12 +286,13 @@ def merge(inputs, collections=None, name=None): return val -def merge_all(key=_ops.GraphKeys.SUMMARIES): +def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None): """Merges all summaries collected in the default graph. Args: key: `GraphKey` used to collect the summaries. Defaults to `GraphKeys.SUMMARIES`. + scope: Optional scope used to filter the summary ops, using `re.match` Returns: If no summaries were collected, returns None. Otherwise returns a scalar @@ -310,7 +311,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES): raise RuntimeError( 'Merging tf.summary.* ops is not compatible with eager execution. ' 'Use tf.contrib.summary instead.') - summary_ops = _ops.get_collection(key) + summary_ops = _ops.get_collection(key, scope=scope) if not summary_ops: return None else: diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index 0ddf09260b6865b4bac5b580459e6080dae7ada0..a2e86a1c43a9c27041d963b2b8d7af582e1054c7 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -72,7 +72,8 @@ def freeze_graph_with_def_protos(input_graph_def, variable_names_blacklist="", input_meta_graph_def=None, input_saved_model_dir=None, - saved_model_tags=None): + saved_model_tags=None, + checkpoint_version=saver_pb2.SaverDef.V2): """Converts all variables in a graph and checkpoint into constants.""" del restore_op_name, filename_tensor_name # Unused by updated loading code. @@ -100,7 +101,8 @@ def freeze_graph_with_def_protos(input_graph_def, _ = importer.import_graph_def(input_graph_def, name="") with session.Session() as sess: if input_saver_def: - saver = saver_lib.Saver(saver_def=input_saver_def) + saver = saver_lib.Saver(saver_def=input_saver_def, + write_version=checkpoint_version) saver.restore(sess, input_checkpoint) elif input_meta_graph_def: restorer = saver_lib.import_meta_graph( @@ -124,7 +126,8 @@ def freeze_graph_with_def_protos(input_graph_def, # 'global_step' or a similar housekeeping element) so skip it. continue var_list[key] = tensor - saver = saver_lib.Saver(var_list=var_list) + saver = saver_lib.Saver(var_list=var_list, + write_version=checkpoint_version) saver.restore(sess, input_checkpoint) if initializer_nodes: sess.run(initializer_nodes.split(",")) @@ -217,7 +220,8 @@ def freeze_graph(input_graph, variable_names_blacklist="", input_meta_graph=None, input_saved_model_dir=None, - saved_model_tags=tag_constants.SERVING): + saved_model_tags=tag_constants.SERVING, + checkpoint_version=saver_pb2.SaverDef.V2): """Converts all variables in a graph and checkpoint into constants.""" input_graph_def = None if input_saved_model_dir: @@ -236,7 +240,8 @@ def freeze_graph(input_graph, input_graph_def, input_saver_def, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist, variable_names_blacklist, - input_meta_graph_def, input_saved_model_dir, saved_model_tags.split(",")) + input_meta_graph_def, input_saved_model_dir, + saved_model_tags.split(","), checkpoint_version=checkpoint_version) def main(unused_args): @@ -246,7 +251,7 @@ def main(unused_args): FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes, FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist, FLAGS.input_meta_graph, FLAGS.input_saved_model_dir, - FLAGS.saved_model_tags) + FLAGS.saved_model_tags, checkpoint_version=checkpoint_version) if __name__ == "__main__": @@ -267,6 +272,11 @@ if __name__ == "__main__": type=str, default="", help="TensorFlow variables file to load.") + parser.add_argument( + "--checkpoint_version", + type=int, + default=saver_pb2.SaverDef.V2, + help="Tensorflow variable file format") parser.add_argument( "--output_graph", type=str, diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index feeed7102cd49a79d0280cc04431de00ad3286d5..342732465d48f40a4ffeac97146fb1b6d564c568 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -86,7 +86,8 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): freeze_graph.freeze_graph( input_graph_path, input_saver_def_path, input_binary, checkpoint_path, output_node_names, restore_op_name, filename_tensor_name, - output_graph_path, clear_devices, "", "", input_meta_graph) + output_graph_path, clear_devices, "", "", input_meta_graph, + checkpoint_version=saver_write_version) # Now we make sure the variable is now a constant, and that the graph still # produces the expected result. diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index ce64fdf70981cd78ac9dc7e5dbae15b90df654a2..21e8e803fcb3d12a2e41b5f9e2810742ec220be8 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -33,6 +33,7 @@ import numpy as np from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.core.example import example_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.debug.wrappers import local_cli_wrapper @@ -377,7 +378,7 @@ def preprocess_input_exprs_arg_string(input_exprs_str): 'input_key=' Returns: - A dictionary that maps input keys to python expressions. + A dictionary that maps input keys to their values. Raises: RuntimeError: An error when the given input string is in a bad format. @@ -388,17 +389,75 @@ def preprocess_input_exprs_arg_string(input_exprs_str): if '=' not in input_exprs_str: raise RuntimeError('--input_exprs "%s" format is incorrect. Please follow' '"="' % input_exprs_str) - input_key, expr = input_raw.split('=') - input_dict[input_key] = expr + input_key, expr = input_raw.split('=', 1) + # ast.literal_eval does not work with numpy expressions + input_dict[input_key] = eval(expr) # pylint: disable=eval-used + return input_dict + +def preprocess_input_examples_arg_string(input_examples_str): + """Parses input into dict that maps input keys to lists of tf.Example. + + Parses input string in the format of 'input_key1=[{feature_name: + feature_list}];input_key2=[{feature_name:feature_list}];' into a dictionary + that maps each input_key to its list of serialized tf.Example. + + Args: + input_examples_str: A string that specifies a list of dictionaries of + feature_names and their feature_lists for each input. + Each input is separated by semicolon. For each input key: + 'input=[{feature_name1: feature_list1, feature_name2:feature_list2}]' + items in feature_list can be the type of float, int, long or str. + + Returns: + A dictionary that maps input keys to lists of serialized tf.Example. + + Raises: + ValueError: An error when the given tf.Example is not a list. + """ + input_dict = preprocess_input_exprs_arg_string(input_examples_str) + for input_key, example_list in input_dict.items(): + if not isinstance(example_list, list): + raise ValueError( + 'tf.Example input must be a list of dictionaries, but "%s" is %s' % + (example_list, type(example_list))) + input_dict[input_key] = [ + _create_example_string(example) for example in example_list + ] return input_dict -def load_inputs_from_input_arg_string(inputs_str, input_exprs_str): +def _create_example_string(example_dict): + """Create a serialized tf.example from feature dictionary.""" + example = example_pb2.Example() + for feature_name, feature_list in example_dict.items(): + if not isinstance(feature_list, list): + raise ValueError('feature value must be a list, but %s: "%s" is %s' % + (feature_name, feature_list, type(feature_list))) + if isinstance(feature_list[0], float): + example.features.feature[feature_name].float_list.value.extend( + feature_list) + elif isinstance(feature_list[0], str): + example.features.feature[feature_name].bytes_list.value.extend( + feature_list) + elif isinstance(feature_list[0], (int, long)): + example.features.feature[feature_name].int64_list.value.extend( + feature_list) + else: + raise ValueError( + 'Type %s for value %s is not supported for tf.train.Feature.' % + (type(feature_list[0]), feature_list[0])) + return example.SerializeToString() + + +def load_inputs_from_input_arg_string(inputs_str, input_exprs_str, + input_examples_str): """Parses input arg strings and create inputs feed_dict. Parses '--inputs' string for inputs to be loaded from file, and parses '--input_exprs' string for inputs to be evaluated from python expression. + '--input_examples' string for inputs to be created from tf.example feature + dictionary list. Args: inputs_str: A string that specified where to load inputs. Each input is @@ -424,9 +483,11 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str): to the specified input tensor, else SavedModel CLI will assume a dictionary is stored in the pickle file and the value corresponding to the variable_name will be used. - input_exprs_str: A string that specified python expressions for inputs. + input_exprs_str: A string that specifies python expressions for inputs. * In the format of: '='. * numpy module is available as np. + input_examples_str: A string that specifies tf.Example with dictionary. + * In the format of: '=<[{feature:value list}]>' Returns: A dictionary that maps input tensor keys to numpy ndarrays. @@ -441,6 +502,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str): inputs = preprocess_inputs_arg_string(inputs_str) input_exprs = preprocess_input_exprs_arg_string(input_exprs_str) + input_examples = preprocess_input_examples_arg_string(input_examples_str) for input_tensor_key, (filename, variable_name) in inputs.items(): data = np.load(filename) @@ -474,15 +536,20 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str): tensor_key_feed_dict[input_tensor_key] = data # When input is a python expression: - for input_tensor_key, py_expr in input_exprs.items(): + for input_tensor_key, py_expr_evaluated in input_exprs.items(): if input_tensor_key in tensor_key_feed_dict: warnings.warn( 'input_key %s has been specified with both --inputs and --input_exprs' ' options. Value in --input_exprs will be used.' % input_tensor_key) + tensor_key_feed_dict[input_tensor_key] = py_expr_evaluated - # ast.literal_eval does not work with numpy expressions - tensor_key_feed_dict[input_tensor_key] = eval(py_expr) # pylint: disable=eval-used - + # When input is a tf.Example: + for input_tensor_key, example in input_examples.items(): + if input_tensor_key in tensor_key_feed_dict: + warnings.warn( + 'input_key %s has been specified in multiple options. Value in ' + '--input_examples will be used.' % input_tensor_key) + tensor_key_feed_dict[input_tensor_key] = example return tensor_key_feed_dict @@ -518,11 +585,12 @@ def run(args): AttributeError: An error when neither --inputs nor --input_exprs is passed to run command. """ - if not args.inputs and not args.input_exprs: + if not args.inputs and not args.input_exprs and not args.input_examples: raise AttributeError( - 'At least one of --inputs and --input_exprs must be required') + 'At least one of --inputs, --input_exprs or --input_examples must be ' + 'required') tensor_key_feed_dict = load_inputs_from_input_arg_string( - args.inputs, args.input_exprs) + args.inputs, args.input_exprs, args.input_examples) run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def, tensor_key_feed_dict, args.outdir, args.overwrite, tf_debug=args.tf_debug) @@ -589,10 +657,12 @@ def create_parser(): run_msg = ('Usage example:\n' 'To run input tensors from files through a MetaGraphDef and save' ' the output tensors to files:\n' - '$saved_model_cli show --dir /tmp/saved_model --tag_set serve' + '$saved_model_cli show --dir /tmp/saved_model --tag_set serve ' '--signature_def serving_default ' - '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy' - '--input_exprs \'input3_key=np.ones(2)\' --outdir=/out\n\n' + '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy ' + '--input_exprs \'input3_key=np.ones(2)\' --input_examples ' + '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' ' + '--outdir=/out\n\n' 'For more information about input file format, please see:\n' 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') parser_run = subparsers.add_parser( @@ -620,8 +690,14 @@ def create_parser(): msg = ('Specifying inputs by python expressions, in the format of' ' "=\'\'", separated by \';\'. ' 'numpy module is available as \'np\'. ' - 'Will override duplicate input_keys from --inputs option.') + 'Will override duplicate input keys from --inputs option.') parser_run.add_argument('--input_exprs', type=str, default='', help=msg) + msg = ( + 'Specifying tf.Example inputs as list of dictionaries. For example: ' + '=[{feature0:value_list,feature1:value_list}]. Use ";" to ' + 'separate input keys. Will override duplicate input keys from --inputs ' + 'and --input_exprs option.') + parser_run.add_argument('--input_examples', type=str, default='', help=msg) parser_run.add_argument( '--outdir', type=str, diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index 0789e1e107cf63b41e37dd7afea0e673d93b2f89..d6cbc49ba1e08a6b808b228fb8d69fc14f36e3d2 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -218,8 +218,9 @@ Method name is: tensorflow/serving/predict""" input_expr_str) self.assertTrue(input_dict['input1'] == ('/path/file.txt', 'ab3')) self.assertTrue(input_dict['input2'] == ('file2', None)) - self.assertTrue(input_expr_dict['input3'] == 'np.zeros([2,2])') - self.assertTrue(input_expr_dict['input4'] == '[4,5]') + print(input_expr_dict['input3']) + self.assertAllClose(input_expr_dict['input3'], np.zeros([2, 2])) + self.assertAllClose(input_expr_dict['input4'], [4, 5]) self.assertTrue(len(input_dict) == 2) self.assertTrue(len(input_expr_dict) == 2) @@ -250,7 +251,8 @@ Method name is: tensorflow/serving/predict""" np.save(input0_path, x0) np.save(input1_path, x1) input_str = 'x0=' + input0_path + '[x0];x1=' + input1_path - feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '') + feed_dict = saved_model_cli.load_inputs_from_input_arg_string( + input_str, '', '') self.assertTrue(np.all(feed_dict['x0'] == x0)) self.assertTrue(np.all(feed_dict['x1'] == x1)) @@ -259,7 +261,8 @@ Method name is: tensorflow/serving/predict""" input_path = os.path.join(test.get_temp_dir(), 'input.npz') np.savez(input_path, a=x0) input_str = 'x=' + input_path + '[a];y=' + input_path - feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '') + feed_dict = saved_model_cli.load_inputs_from_input_arg_string( + input_str, '', '') self.assertTrue(np.all(feed_dict['x'] == x0)) self.assertTrue(np.all(feed_dict['y'] == x0)) @@ -278,7 +281,8 @@ Method name is: tensorflow/serving/predict""" pickle.dump(pkl2, f) input_str = 'x=' + input_path0 + '[b];y=' + input_path1 + '[c];' input_str += 'z=' + input_path2 - feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '') + feed_dict = saved_model_cli.load_inputs_from_input_arg_string( + input_str, '', '') self.assertTrue(np.all(feed_dict['x'] == pkl0['b'])) self.assertTrue(np.all(feed_dict['y'] == pkl1)) self.assertTrue(np.all(feed_dict['z'] == pkl2)) @@ -291,7 +295,7 @@ Method name is: tensorflow/serving/predict""" input_expr_str = ('x1=np.ones([2,10]);x2=np.array([[1],[2],[3]]);' 'x3=np.mgrid[0:5,0:5];x4=[[3],[4]]') feed_dict = saved_model_cli.load_inputs_from_input_arg_string( - '', input_expr_str) + '', input_expr_str, '') self.assertTrue(np.all(feed_dict['x1'] == x1)) self.assertTrue(np.all(feed_dict['x2'] == x2)) self.assertTrue(np.all(feed_dict['x3'] == x3)) @@ -305,7 +309,7 @@ Method name is: tensorflow/serving/predict""" input_str = 'x0=' + input_path + '[a]' input_expr_str = 'x1=np.ones([2,10])' feed_dict = saved_model_cli.load_inputs_from_input_arg_string( - input_str, input_expr_str) + input_str, input_expr_str, '') self.assertTrue(np.all(feed_dict['x0'] == x0)) self.assertTrue(np.all(feed_dict['x1'] == x1)) @@ -317,7 +321,7 @@ Method name is: tensorflow/serving/predict""" input_str = 'x0=' + input_path + '[a]' input_expr_str = 'x0=np.ones([2,10])' feed_dict = saved_model_cli.load_inputs_from_input_arg_string( - input_str, input_expr_str) + input_str, input_expr_str, '') self.assertTrue(np.all(feed_dict['x0'] == x1)) def testInputParserErrorNoName(self): @@ -327,7 +331,7 @@ Method name is: tensorflow/serving/predict""" np.savez(input_path, a=x0, b=x1) input_str = 'x=' + input_path with self.assertRaises(RuntimeError): - saved_model_cli.load_inputs_from_input_arg_string(input_str, '') + saved_model_cli.load_inputs_from_input_arg_string(input_str, '', '') def testInputParserErrorWrongName(self): x0 = np.array([[1], [2]]) @@ -336,7 +340,22 @@ Method name is: tensorflow/serving/predict""" np.savez(input_path, a=x0, b=x1) input_str = 'x=' + input_path + '[c]' with self.assertRaises(RuntimeError): - saved_model_cli.load_inputs_from_input_arg_string(input_str, '') + saved_model_cli.load_inputs_from_input_arg_string(input_str, '', '') + + def testRunCommandInputExamples(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + output_dir = os.path.join(test.get_temp_dir(), 'new_dir') + args = self.parser.parse_args([ + 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def', + 'regress_x_to_y', '--input_examples', + 'inputs=[{"x":[8.0],"x2":[5.0]}, {"x":[4.0],"x2":[3.0]}]', '--outdir', + output_dir + ]) + saved_model_cli.run(args) + y_actual = np.load(os.path.join(output_dir, 'outputs.npy')) + y_expected = np.array([[6.0], [4.0]]) + self.assertAllEqual(y_expected, y_actual) def testRunCommandExistingOutdir(self): self.parser = saved_model_cli.create_parser() @@ -410,6 +429,42 @@ Method name is: tensorflow/serving/predict""" with self.assertRaises(ValueError): saved_model_cli.run(args) + def testRunCommandInputExamplesNotListError(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + output_dir = os.path.join(test.get_temp_dir(), 'new_dir') + args = self.parser.parse_args([ + 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def', + 'regress_x_to_y', '--input_examples', 'inputs={"x":8.0,"x2":5.0}', + '--outdir', output_dir + ]) + with self.assertRaisesRegexp(ValueError, 'must be a list'): + saved_model_cli.run(args) + + def testRunCommandInputExamplesFeatureValueNotListError(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + output_dir = os.path.join(test.get_temp_dir(), 'new_dir') + args = self.parser.parse_args([ + 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def', + 'regress_x_to_y', '--input_examples', 'inputs=[{"x":8.0,"x2":5.0}]', + '--outdir', output_dir + ]) + with self.assertRaisesRegexp(ValueError, 'feature value must be a list'): + saved_model_cli.run(args) + + def testRunCommandInputExamplesFeatureBadType(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + output_dir = os.path.join(test.get_temp_dir(), 'new_dir') + args = self.parser.parse_args([ + 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def', + 'regress_x_to_y', '--input_examples', 'inputs=[{"x":[[1],[2]]}]', + '--outdir', output_dir + ]) + with self.assertRaisesRegexp(ValueError, 'is not supported'): + saved_model_cli.run(args) + def testRunCommandOutputFileExistError(self): self.parser = saved_model_cli.create_parser() base_path = test.test_src_dir_path(SAVED_MODEL_PATH) diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index 266f5563e0c738fe73e3a771a46e9b28c266cd73..0c69f8bf3997452f0eeb71c93f4fcf98eb27d8f9 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -24,7 +24,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops @@ -101,19 +100,16 @@ class AdamOptimizer(optimizer.Optimizer): self._beta2_t = None self._epsilon_t = None - # Variables to accumulate the powers of the beta parameters. - # Created in _create_slots when we know the variables to optimize. - self._beta1_power = None - self._beta2_power = None - # Created in SparseApply if needed. self._updated_lr = None def _get_beta_accumulators(self): - return self._beta1_power, self._beta2_power - - def _non_slot_variables(self): - return self._get_beta_accumulators() + if context.in_graph_mode(): + graph = ops.get_default_graph() + else: + graph = None + return (self._get_non_slot_variable("beta1_power", graph=graph), + self._get_non_slot_variable("beta2_power", graph=graph)) def _create_slots(self, var_list): # Create the beta1 and beta2 accumulators on the same device as the first @@ -121,19 +117,13 @@ class AdamOptimizer(optimizer.Optimizer): # workers (these need to go on the same PS, otherwise some updates are # silently ignored). first_var = min(var_list, key=lambda x: x.name) + self._create_non_slot_variable(initial_value=self._beta1, + name="beta1_power", + colocate_with=first_var) + self._create_non_slot_variable(initial_value=self._beta2, + name="beta2_power", + colocate_with=first_var) - create_new = self._beta1_power is None - if not create_new and context.in_graph_mode(): - create_new = (self._beta1_power.graph is not first_var.graph) - - if create_new: - with ops.colocate_with(first_var): - self._beta1_power = variable_scope.variable(self._beta1, - name="beta1_power", - trainable=False) - self._beta2_power = variable_scope.variable(self._beta2, - name="beta2_power", - trainable=False) # Create slots for the first and second moments. for v in var_list: self._zeros_slot(v, "m", self._name) @@ -148,10 +138,11 @@ class AdamOptimizer(optimizer.Optimizer): def _apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.apply_adam( var, m, v, - math_ops.cast(self._beta1_power, var.dtype.base_dtype), - math_ops.cast(self._beta2_power, var.dtype.base_dtype), + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), math_ops.cast(self._lr_t, var.dtype.base_dtype), math_ops.cast(self._beta1_t, var.dtype.base_dtype), math_ops.cast(self._beta2_t, var.dtype.base_dtype), @@ -161,10 +152,11 @@ class AdamOptimizer(optimizer.Optimizer): def _resource_apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.resource_apply_adam( var.handle, m.handle, v.handle, - math_ops.cast(self._beta1_power, grad.dtype.base_dtype), - math_ops.cast(self._beta2_power, grad.dtype.base_dtype), + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), math_ops.cast(self._lr_t, grad.dtype.base_dtype), math_ops.cast(self._beta1_t, grad.dtype.base_dtype), math_ops.cast(self._beta2_t, grad.dtype.base_dtype), @@ -172,8 +164,9 @@ class AdamOptimizer(optimizer.Optimizer): grad, use_locking=self._use_locking) def _apply_sparse_shared(self, grad, var, indices, scatter_add): - beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) - beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) @@ -217,12 +210,11 @@ class AdamOptimizer(optimizer.Optimizer): def _finish(self, update_ops, name_scope): # Update the power accumulators. with ops.control_dependencies(update_ops): - with ops.colocate_with(self._beta1_power): - update_beta1 = self._beta1_power.assign( - self._beta1_power * self._beta1_t, - use_locking=self._use_locking) - update_beta2 = self._beta2_power.assign( - self._beta2_power * self._beta2_t, - use_locking=self._use_locking) + beta1_power, beta2_power = self._get_beta_accumulators() + with ops.colocate_with(beta1_power): + update_beta1 = beta1_power.assign( + beta1_power * self._beta1_t, use_locking=self._use_locking) + update_beta2 = beta2_power.assign( + beta2_power * self._beta2_t, use_locking=self._use_locking) return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], name=name_scope) diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py index ffb66abc4c1a38353d602a711cab86b0d63b9e96..a521f1299e035424d1c3897a469655db732b0dcd 100644 --- a/tensorflow/python/training/adam_test.py +++ b/tensorflow/python/training/adam_test.py @@ -174,8 +174,11 @@ class AdamOptimizerTest(test.TestCase): opt = adam.AdamOptimizer() update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) opt_variables = opt.variables() - self.assertIn(opt._beta1_power, opt_variables) - self.assertIn(opt._beta2_power, opt_variables) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertTrue(beta2_power is not None) + self.assertIn(beta1_power, opt_variables) + self.assertIn(beta2_power, opt_variables) with ops.Graph().as_default(): # Shouldn't return non-slot variables from other graphs. diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 56cf4d42ee194885057d8bf45d9b3c1c407c4a11..038469b1bac9d2fabce788340278ea165f2f9249 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import slot_creator from tensorflow.python.util import nest @@ -299,6 +300,7 @@ class Optimizer(object): # Dictionary of slots. # {slot_name : { variable_to_train: slot_for_the_variable, ...}, ... } self._slots = {} + self._non_slot_dict = {} def get_name(self): return self._name @@ -603,17 +605,32 @@ class Optimizer(object): # Sort variables by name so that the return is deterministic. return sorted(optimizer_variables, key=lambda v: v.name) + def _create_non_slot_variable(self, initial_value, name, colocate_with): + """Add an extra variable, not associated with a slot.""" + if context.in_graph_mode(): + graph = colocate_with.graph + else: + graph = None + + key = (name, graph) + v = self._non_slot_dict.get(key, None) + if v is None: + with ops.colocate_with(colocate_with): + v = variable_scope.variable(initial_value, name=name, trainable=False) + self._non_slot_dict[key] = v + + return v + + def _get_non_slot_variable(self, name, graph=None): + return self._non_slot_dict.get((name, graph), None) + def _non_slot_variables(self): """Additional variables created by the `Optimizer`. - This method should be overridden by child classes which create extra - variables, so that `variables()` includes the `Optimizer`'s non-slot - variables. - Returns: A list or tuple of variables. """ - return [] + return self._non_slot_dict.values() def _assert_valid_dtypes(self, tensors): """Asserts tensors are all valid types (see `_valid_dtypes`). diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 2c59b82ebe2264e56da1a3b977b27eba2ed6f494..4f3773c0fc71e1f1abd8197dea94ce2a63881389 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1592,9 +1592,9 @@ class Saver(object): [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: - A string: path prefix used for the checkpoint files. If the saver is - sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn' - is the number of shards created. + A string: path prefix used for the checkpoint files. If checkpoint + format is V1 and the saver is sharded, this string ends with: + '-?????-of-nnnnn' where 'nnnnn' is the number of shards created. If the saver is empty, returns None. Raises: @@ -1744,6 +1744,11 @@ class Saver(object): return if save_path is None: raise ValueError("Can't load save_path when it is None.") + if (os.path.isfile(save_path) and + self._write_version != saver_pb2.SaverDef.V1): + raise ValueError("The specified path: %s is a file." + " Please specify only the path prefix" + " to the checkpoint files." % save_path) logging.info("Restoring parameters from %s", save_path) if context.in_graph_mode(): sess.run(self.saver_def.restore_op_name, diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py index 297284f80c2997e21304138c5a090da76425917b..fff17402e23cb7b054d3e433650666b0554ed8ba 100644 --- a/tensorflow/python/training/sync_replicas_optimizer_test.py +++ b/tensorflow/python/training/sync_replicas_optimizer_test.py @@ -286,8 +286,9 @@ class SyncReplicasOptimizerHookTest(test.TestCase): global_step = variables.Variable(0, name="global_step", trainable=False) opt.minimize(v, global_step=global_step) opt_variables = opt.variables() - self.assertIn(opt._opt._beta1_power, opt_variables) - self.assertIn(opt._opt._beta2_power, opt_variables) + beta1_power, beta2_power = opt._opt._get_beta_accumulators() + self.assertIn(beta1_power, opt_variables) + self.assertIn(beta2_power, opt_variables) if __name__ == "__main__": diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 07382d93dfe5ebe3f063b86bc5afa288970330f6..270d96a3c7c831d8c06dd86199cf2dc5dfc43421 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Functions for Python 2 vs. 3 compatibility. ## Conversion routines @@ -21,6 +20,7 @@ In addition to the functions below, `as_str` converts an object to a `str`. @@as_bytes @@as_text @@as_str_any +@@path_to_str ## Types The compatibility module also provides the following types: @@ -108,17 +108,29 @@ def as_str_any(value): return str(value) +def path_to_str(path): + """Returns the file system path representation of a `PathLike` object, else as it is. + + Args: + path: An object that can be converted to path representation. + + Returns: + A `str` object. + """ + if hasattr(path, '__fspath__'): + path = as_str_any(path.__fspath__()) + return path + + # Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we # need to check them specifically. The same goes from Real and Complex. integral_types = (_numbers.Integral, _np.integer) real_types = (_numbers.Real, _np.integer, _np.floating) complex_types = (_numbers.Complex, _np.number) - # Either bytes or text. bytes_or_text_types = (bytes, _six.text_type) - _allowed_symbols = [ 'as_str', 'bytes_or_text_types', diff --git a/tensorflow/python/util/compat_internal.py b/tensorflow/python/util/compat_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..a299b2fc3c302705d9493904e8ac0f81e4b8d371 --- /dev/null +++ b/tensorflow/python/util/compat_internal.py @@ -0,0 +1,34 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functions for Python 2 vs. 3 compatibility that are private to TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def path_to_str(path): + """Returns the file system path representation of a `PathLike` object, else as it is. + + Args: + path: An object that can be converted to path representation. + + Returns: + A `str` object. + """ + if hasattr(path, "__fspath__"): + path = as_str_any(path.__fspath__()) + return path diff --git a/tensorflow/python/util/kernel_registry.h b/tensorflow/python/util/kernel_registry.h index c00b60d91b3737966536d02281ed7a31a238b82f..1ba76f020bf3916704fb3a2d76895650fe093cfa 100644 --- a/tensorflow/python/util/kernel_registry.h +++ b/tensorflow/python/util/kernel_registry.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Functions for getting information about kernels registered in the binary. -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ +#ifndef TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ +#define TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/types.h" @@ -31,4 +31,4 @@ string TryFindKernelClass(const string& serialized_node_def); } // namespace swig } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ +#endif // TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 4ce871de72fb43420e25bfa7cd13002b09f83f18..874df3d1087e157f8bfcec12ba3495e341c14b7b 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -47,10 +47,25 @@ def _sorted(dict_): raise TypeError("nest only supports dicts with sortable keys.") -def _is_namedtuple(instance): - """Returns True iff `instance` is a `namedtuple`.""" +def _is_namedtuple(instance, strict=False): + """Returns True iff `instance` is a `namedtuple`. + + Args: + instance: An instance of a Python object. + strict: If True, `instance` is considered to be a `namedtuple` only if + it is a "plain" namedtuple. For instance, a class inheriting + from a `namedtuple` will be considered to be a `namedtuple` + iff `strict=False`. + + Returns: + True if `instance` is a `namedtuple`. + """ + # Attemp to limit the test to plain namedtuple (not stuff inheriting from it). + if not isinstance(instance, tuple): + return False + if strict and instance.__class__.__base__ != tuple: + return False return ( - isinstance(instance, tuple) and hasattr(instance, "_fields") and isinstance(instance._fields, _collections.Sequence) and all(isinstance(f, _six.string_types) for f in instance._fields)) @@ -140,8 +155,37 @@ def flatten(nest): return _pywrap_tensorflow.Flatten(nest) +def _same_namedtuples(nest1, nest2): + """Returns True if the two namedtuples have the same name and fields.""" + if nest1._fields != nest2._fields: + return False + if nest1.__class__.__name__ != nest2.__class__.__name__: + return False + return True + + def _recursive_assert_same_structure(nest1, nest2, check_types): - """Helper function for `assert_same_structure`.""" + """Helper function for `assert_same_structure`. + + See `assert_same_structure` for further information about namedtuples. + + Args: + nest1: An arbitrarily nested structure. + nest2: An arbitrarily nested structure. + check_types: If `True` (default) types of sequences are checked as + well, including the keys of dictionaries. If set to `False`, for example + a list and a tuple of objects will look the same if they have the same + size. Note that namedtuples with identical name and fields are always + considered to have the same shallow structure. + + Returns: + True if `nest1` and `nest2` have the same structure. + + Raises: + ValueError: If the two structure don't have the same nested structre. + TypeError: If the two structure don't have the same sequence type. + ValueError: If the two dictionaries don't have the same set of keys. + """ is_sequence_nest1 = is_sequence(nest1) if is_sequence_nest1 != is_sequence(nest2): raise ValueError( @@ -154,11 +198,21 @@ def _recursive_assert_same_structure(nest1, nest2, check_types): if check_types: type_nest1 = type(nest1) type_nest2 = type(nest2) - if type_nest1 != type_nest2: - raise TypeError( - "The two structures don't have the same sequence type. First " - "structure has type %s, while second structure has type %s." - % (type_nest1, type_nest2)) + + # Duck-typing means that nest should be fine with two different namedtuples + # with identical name and fields. + if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True): + if not _same_namedtuples(nest1, nest2): + raise TypeError( + "The two namedtuples don't have the same sequence type. First " + "structure has type %s, while second structure has type %s." + % (type_nest1, type_nest2)) + else: + if type_nest1 != type_nest2: + raise TypeError( + "The two structures don't have the same sequence type. First " + "structure has type %s, while second structure has type %s." + % (type_nest1, type_nest2)) if isinstance(nest1, dict): keys1 = set(_six.iterkeys(nest1)) @@ -178,13 +232,24 @@ def _recursive_assert_same_structure(nest1, nest2, check_types): def assert_same_structure(nest1, nest2, check_types=True): """Asserts that two structures are nested in the same way. + Note that namedtuples with identical name and fields are always considered + to have the same shallow structure (even with `check_types=True`). + For intance, this code will print `True`: + + ```python + def nt(a, b): + return collections.namedtuple('foo', 'a b')(a, b) + print(assert_same_structure(nt(0, 1), nt(2, 3))) + ``` + Args: nest1: an arbitrarily nested structure. nest2: an arbitrarily nested structure. check_types: if `True` (default) types of sequences are checked as well, including the keys of dictionaries. If set to `False`, for example a list and a tuple of objects will look the same if they have the same - size. + size. Note that namedtuples with identical name and fields are always + considered to have the same shallow structure. Raises: ValueError: If the two structures do not have the same number of elements or @@ -354,6 +419,8 @@ def map_structure(func, *structure, **check_types_dict): `True` (default) the types of iterables within the structures have to be same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow this set this argument to `False`. + Note that namedtuples with identical name and fields are always + considered to have the same shallow structure. Returns: A new structure with the same arity as `structure`, whose values correspond diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 4906649f013da38f6b18f1645958aa4b244a9d05..6bec397db577c5be5847a701ccc92367dc008fc9 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -258,6 +258,36 @@ class NestTest(test.TestCase): "don't have the same set of keys"): nest.assert_same_structure({"a": 1}, {"b": 1}) + same_name_type_0 = collections.namedtuple("same_name", ("a", "b")) + same_name_type_1 = collections.namedtuple("same_name", ("a", "b")) + nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3)) + + # This assertion is expected to pass: two namedtuples with the same + # name and field names are considered to be identical. + same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y")) + same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y")) + nest.assert_same_structure( + same_name_type_0(same_name_type_2(0, 1), 2), + same_name_type_1(same_name_type_3(2, 3), 4)) + + expected_message = "The two structures don't have the same.*" + with self.assertRaisesRegexp(ValueError, expected_message): + nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)), + same_name_type_1(same_name_type_0(0, 1), 2)) + + same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b")) + self.assertRaises(TypeError, nest.assert_same_structure, + same_name_type_0(0, 1), same_name_type_1(2, 3)) + + same_name_type_1 = collections.namedtuple("same_name", ("x", "y")) + self.assertRaises(TypeError, nest.assert_same_structure, + same_name_type_0(0, 1), same_name_type_1(2, 3)) + + class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))): + pass + self.assertRaises(TypeError, nest.assert_same_structure, + same_name_type_0(0, 1), SameNamedType1(2, 3)) + def testMapStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) structure2 = (((7, 8), 9), 10, (11, 12)) diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 493d26b497d714b318a345c96462d2d01de789c9..2af71dc753760e7efaf28cc500d5296a31957a04 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Functions for getting information about kernels registered in the binary. -#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_UTIL_H_ +#ifndef TENSORFLOW_PYTHON_UTIL_UTIL_H_ +#define TENSORFLOW_PYTHON_UTIL_UTIL_H_ #include @@ -71,4 +71,4 @@ void RegisterSequenceClass(PyObject* sequence_class); } // namespace swig } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_UTIL_H_ +#endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index f35542e18fdba2b92f12b950e432937d0a1ef577..933c103f524ef37f840c9e13b9e4024289e274c1 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -232,7 +232,7 @@ port::StatusOr Diagnostician::FindDsoVersion() { result = StringToDriverVersion(version); } #else -#if !defined(PLATFORM_WINDOWS) && !defined(NVIDIA_TEGRA) +#if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA) // Callback used when iterating through DSOs. Looks for the driver-interfacing // DSO and yields its version number into the callback data, when found. auto iterate_phdr = diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc index 5210a81092b3023563baa7edbb657b630dfc819a..d71938634d6e6fe092d9a1e0861215bb101e824f 100644 --- a/tensorflow/stream_executor/dso_loader.cc +++ b/tensorflow/stream_executor/dso_loader.cc @@ -96,10 +96,18 @@ string GetCudnnVersion() { return TF_CUDNN_VERSION; } } /* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) { +#if defined(ANDROID_TEGRA) + // On Android devices the CUDA version number is not added to the library name. + return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName( + "cupti", ""), + GetCudaCuptiLibraryPath()), + dso_handle); +#else return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName( "cupti", GetCudaVersion()), GetCudaCuptiLibraryPath()), dso_handle); +#endif } static mutex& GetRpathMutex() { diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 383c97344a068e0174037f986baca21671f376e7..f32d4561550c0ff60511047c87821dffe736c935 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -258,6 +258,8 @@ def _rpath_linkopts(name): clean_dep("//tensorflow:darwin"): [ "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),), ], + clean_dep("//tensorflow:windows"): [], + clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root),), ], @@ -289,6 +291,7 @@ def tf_cc_shared_object( "-Wl,-install_name,@rpath/" + name.split("/")[-1], ], "//conditions:default": [ + "-Wl,-soname," + name.split("/")[-1], ], }), **kwargs) @@ -600,6 +603,8 @@ def tf_cc_test(name, "//tensorflow:android": [ "-pie", ], + clean_dep("//tensorflow:windows"): [], + clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-lpthread", "-lm" @@ -1246,6 +1251,8 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): "//conditions:default": [ "-lm", ], + clean_dep("//tensorflow:windows"): [], + clean_dep("//tensorflow:windows_msvc"): [], clean_dep("//tensorflow:darwin"): [], }),) diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index fa0f9b59aa938168cb3d318797c797eeabc9c7d9..d11031639592aa1d3e6ce1c7f09c2f0679b29854 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -46,11 +46,37 @@ genrule( "api/bitwise/__init__.py", "api/contrib/__init__.py", "api/contrib/stat_summarizer/__init__.py", + "api/distributions/__init__.py", + "api/distributions/bijectors/__init__.py", + "api/errors/__init__.py", "api/image/__init__.py", "api/linalg/__init__.py", "api/nn/__init__.py", "api/spectral/__init__.py", "api/train/__init__.py", + "api/app/__init__.py", + "api/gfile/__init__.py", + "api/graph_util/__init__.py", + "api/keras/__init__.py", + "api/keras/backend/__init__.py", + "api/keras/datasets/__init__.py", + "api/keras/datasets/boston_housing/__init__.py", + "api/keras/datasets/cifar10/__init__.py", + "api/keras/datasets/cifar100/__init__.py", + "api/keras/datasets/imdb/__init__.py", + "api/keras/datasets/mnist/__init__.py", + "api/keras/datasets/reuters/__init__.py", + "api/keras/utils/__init__.py", + "api/logging/__init__.py", + "api/resource_loader/__init__.py", + "api/sysconfig/__init__.py", + "api/test/__init__.py", + "api/initializers/__init__.py", + "api/keras/initializers/__init__.py", + "api/metrics/__init__.py", + "api/nn/rnn_cell/__init__.py", + "api/sets/__init__.py", + "api/summary/__init__.py", ], cmd = "$(location create_python_api) $(OUTS)", tools = ["create_python_api"], @@ -60,4 +86,7 @@ py_library( name = "python_api", srcs = [":python_api_gen"], srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib:contrib_py", # keep + ], ) diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index aab856b723cf2686e8fc9feb156b9be28470fc98..1557314939bd85c0467426216f90aa3891ca0ac0 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -31,6 +31,7 @@ from tensorflow.python.util import tf_decorator _API_CONSTANTS_ATTR = '_tf_api_constants' _API_NAMES_ATTR = '_tf_api_names' _API_DIR = '/api/' +_CONTRIB_IMPORT = 'from tensorflow import contrib' _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API. This file is MACHINE GENERATED! Do not edit. @@ -50,11 +51,17 @@ def format_import(source_module_name, source_name, dest_name): Returns: An import statement string. """ - if source_name == dest_name: - return 'from %s import %s' % (source_module_name, source_name) + if source_module_name: + if source_name == dest_name: + return 'from %s import %s' % (source_module_name, source_name) + else: + return 'from %s import %s as %s' % ( + source_module_name, source_name, dest_name) else: - return 'from %s import %s as %s' % ( - source_module_name, source_name, dest_name) + if source_name == dest_name: + return 'import %s' % source_name + else: + return 'import %s as %s' % (source_name, dest_name) def get_api_imports(): @@ -74,6 +81,9 @@ def get_api_imports(): # Only look at tensorflow modules. if not module or 'tensorflow.' not in module.__name__: continue + # Do not generate __init__.py files for contrib modules for now. + if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'): + continue for module_contents_name in dir(module): attr = getattr(module, module_contents_name) @@ -151,21 +161,28 @@ def create_api_files(output_files): os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - # Add imports to output files. module_imports = get_api_imports() + module_imports['tf'].append(_CONTRIB_IMPORT) # Include all of contrib. + + # Add imports to output files. missing_output_files = [] for module, exports in module_imports.items(): # Make sure genrule output file list is in sync with API exports. if module not in module_name_to_file_path: - missing_output_files.append(module) + module_without_tf = module[len('tf.'):] + module_file_path = '"api/%s/__init__.py"' % ( + module_without_tf.replace('.', '/')) + missing_output_files.append(module_file_path) continue with open(module_name_to_file_path[module], 'w') as fp: fp.write(_GENERATED_FILE_HEADER + '\n'.join(exports)) if missing_output_files: raise ValueError( - 'Missing outputs for python_api_gen genrule:\n%s' % - ',\n'.join(missing_output_files)) + 'Missing outputs for python_api_gen genrule:\n%s.' + 'Make sure all required outputs are in the ' + 'tensorflow/tools/api/generator/BUILD file.' % + ',\n'.join(sorted(missing_output_files))) def main(output_files): diff --git a/tensorflow/tools/api/golden/tensorflow.compat.pbtxt b/tensorflow/tools/api/golden/tensorflow.compat.pbtxt index ccc60314001f261a2b4a5560bea83ffa017fd914..bab480ff9b105546790aadb72f3eb88a795ebbff 100644 --- a/tensorflow/tools/api/golden/tensorflow.compat.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.compat.pbtxt @@ -32,4 +32,8 @@ tf_module { name: "as_text" argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " } + member_method { + name: "path_to_str" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt index 46d59570577d0e31f61687e445f24770c561764d..efc441ae2f2a00f663c11f84c1155bece0c8e08a 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt index 439e87375ba09de0b42fb483588bb51bf80b0476..20ce87987060d9013bd071d6fc9f1f4f33467121 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt index f79a8be3f69be7b19b0708023d440922e4cafdeb..73211aaf8ba5f925982afe3d17c4b8f009250cb8 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt index c466dcb4c23eb36e8313df23c68da8ee39104c7b..27a159639d2098aace2e69718d9ac4e38a29fdc3 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt index d0bf043754b60240c507fe34b21b0599b94b69e2..76f527f796e95f342eb144ae3de87ff234338021 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt @@ -20,7 +20,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_fn\', \'model_dir\', \'config\', \'params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_fn\', \'model_dir\', \'config\', \'params\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt index cb9e95588dbec1b3ee367be9b61f6f3bc1f77725..c45318b98a034255d32c326179813de14cf1d4c8 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt index 637f19ba2614265e69147093f2f21f1f9393d244..04a2aa080d0704a8b7ec98f8eafda4bd1944e567 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt index 018e8c909a23a9e7093c1bb411643d7db629b21c..24a58fb118bf52e650e1df71e9374099745ade52 100644 --- a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt @@ -48,6 +48,10 @@ tf_module { name: "numeric_column" argspec: "args=[\'key\', \'shape\', \'default_value\', \'dtype\', \'normalizer_fn\'], varargs=None, keywords=None, defaults=[\'(1,)\', \'None\', \"\", \'None\'], " } + member_method { + name: "shared_embedding_columns" + argspec: "args=[\'categorical_columns\', \'dimension\', \'combiner\', \'initializer\', \'shared_embedding_collection_name\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\'], " + } member_method { name: "weighted_categorical_column" argspec: "args=[\'categorical_column\', \'weight_feature_key\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index f32353c9570cb6c0f6536f5e9093a690c2522db5..baedf596e8fbce921ed7e0570542b8a11655dba4 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -168,13 +168,21 @@ tf_module { name: "rgb_to_hsv" argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "rgb_to_yiq" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "rgb_to_yuv" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "rot90" argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "sample_distorted_bounding_box" - argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "total_variation" @@ -184,4 +192,12 @@ tf_module { name: "transpose_image" argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "yiq_to_rgb" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "yuv_to_rgb" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index d898c546278188ca84a94660d9dc0c7be03e0b24..11e05f884d781166616a9c9a61dacbc8fdae6ae3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.Conv3DTranspose" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index a7001bbe34f899bdba6c49f7d2d1c7d9becc1313..58724a1e1661609ef3c000c7ca1dfe9b3235acff 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.Convolution3DTranspose" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt index a6d9b57c8813acc85436cc08041159c17c252806..5d898fb2bd86b39cb8fab755382bb96cce231fa6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.SeparableConv2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index 551d6953796fcc63b6d9b58fec5a45ef03f6dc2a..c758d87993b3acba88a13c7bc9eaeee929a22652 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.SeparableConvolution2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..05799ecfc9fdb9ff44620a67dcdbdc4426fddced --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt @@ -0,0 +1,144 @@ +path: "tensorflow.layers.SeparableConv1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt index 4d91ab1d8c9b5d2c8f7db5fd645b3c126eb609c2..c2aeb35c4648bcce22ca73c838a85803a6b9cedf 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.layers.SeparableConv2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt index c45d6e6c05054f1c0c61caeaf5e9a3fd7d00983f..59134f84891ad5518dcb5331ce04475482c8b59e 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt @@ -68,6 +68,10 @@ tf_module { name: "MaxPooling3D" mtype: "" } + member { + name: "SeparableConv1D" + mtype: "" + } member { name: "SeparableConv2D" mtype: "" @@ -136,6 +140,10 @@ tf_module { name: "max_pooling3d" argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " } + member_method { + name: "separable_conv1d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } member_method { name: "separable_conv2d" argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt index 8ce022e4549712bb13dedcd66481a3ad2a2db0e5..455590d866a4c1ebea65ccff51e34f2e0b0479d7 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt @@ -262,7 +262,7 @@ tf_module { } member_method { name: "sampled_softmax_loss" - argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\'], " + argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\', \'None\'], " } member_method { name: "selu" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index f61a5a28e3cd249a2cc2c84fc401cecc49a7945c..97edf245f6fbed393a6fb8dbf1e83649e9ac4b4e 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -88,6 +88,10 @@ tf_class { name: "weights" mtype: "" } + member { + name: "wrapped_cell" + mtype: "" + } member_method { name: "__init__" argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\', \'dropout_state_filter_visitor\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt index 326e077d396bc5e3463bba3818f4757127ee0370..871ebb5247f62e9300566da063e4dadeb5087091 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt @@ -50,7 +50,7 @@ tf_module { } member_method { name: "merge_all" - argspec: "args=[\'key\'], varargs=None, keywords=None, defaults=[\'summaries\'], " + argspec: "args=[\'key\', \'scope\'], varargs=None, keywords=None, defaults=[\'summaries\', \'None\'], " } member_method { name: "scalar" diff --git a/tensorflow/tools/benchmark/BUILD b/tensorflow/tools/benchmark/BUILD index caa6629c491477ffcd108c52d7ce20f1ab95a0a9..6ed2594e6abe169577066678e1bf4b9e2df4c4d3 100644 --- a/tensorflow/tools/benchmark/BUILD +++ b/tensorflow/tools/benchmark/BUILD @@ -61,10 +61,11 @@ tf_cc_test( # This binary may be built for either desktop or Android. # A typical Android build command will look like the following: -# bazel build -c opt tensorflow/core:android_tensorflow_lib \ +# bazel build tensorflow/core:android_tensorflow_lib \ # --crosstool_top=//external:android/crosstool \ # --cpu=armeabi-v7a \ # --host_crosstool_top=@bazel_tools//tools/cpp:toolchain +# --config monolithic tf_cc_binary( name = "benchmark_model", testonly = 1, diff --git a/tensorflow/tools/benchmark/README.md b/tensorflow/tools/benchmark/README.md index ca0da2d41b91a385cd57dbe1ebaf4fe83ed380c9..e64af2bfe1a77e6883f0c2c7dd9303e6d4eb4ee6 100644 --- a/tensorflow/tools/benchmark/README.md +++ b/tensorflow/tools/benchmark/README.md @@ -17,6 +17,7 @@ bazel build -c opt \ --crosstool_top=//external:android/crosstool \ --cpu=armeabi-v7a \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --config monolithic \ tensorflow/tools/benchmark:benchmark_model ``` diff --git a/tensorflow/tools/ci_build/builds/libtensorflow.sh b/tensorflow/tools/ci_build/builds/libtensorflow.sh index 26713dded88ce7b93df0200ac84e11f8efb9baf3..9b3ff0cba7dcacc0f68a417299c31f7a0f413430 100755 --- a/tensorflow/tools/ci_build/builds/libtensorflow.sh +++ b/tensorflow/tools/ci_build/builds/libtensorflow.sh @@ -51,8 +51,8 @@ function build_libtensorflow_tarball() { rm -rf ${DIR} TARBALL_SUFFIX="${1}" - BAZEL="bazel --bazelrc ./tensorflow/tools/ci_build/install/.bazelrc" - BAZEL_OPTS="-c opt" + BAZEL_OPTS="-c opt --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" + export CC_OPT_FLAGS='-mavx' if [ "${TF_NEED_CUDA}" == "1" ]; then BAZEL_OPTS="${BAZEL_OPTS} --config=cuda" fi diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index b728c878da0f729c74b20e66cfc97868c3e953f3..aa341b144cb8ef3c9a13635c62a7ae1be90b0994 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -26,6 +26,8 @@ SCRIPT_DIR=$( cd ${0%/*} && pwd -P ) source "${SCRIPT_DIR}/builds/builds_common.sh" +ROOT_DIR=$( cd "$SCRIPT_DIR/../../.." && pwd -P ) + # Helper functions die() { echo $@ @@ -418,15 +420,8 @@ do_bazel_nobuild() { } do_pip_smoke_test() { - BUILD_CMD="bazel build ${BAZEL_FLAGS} //tensorflow/tools/pip_package:pip_smoke_test" - ${BUILD_CMD} - cmd_status \ - "Pip smoke test has failed. Please make sure any new TensorFlow are added to the tensorflow/tools/pip_package:build_pip_package dependencies." - - RUN_CMD="bazel-bin/tensorflow/tools/pip_package/pip_smoke_test" - ${RUN_CMD} - cmd_status \ - "The pip smoke test failed." + cd "$ROOT_DIR/tensorflow/tools/pip_package" + python pip_smoke_test.py } do_code_link_check() { @@ -500,20 +495,23 @@ do_clang_format_check() { } do_check_load_py_test() { - BUILD_CMD="bazel build ${BAZEL_FLAGS} //tensorflow/tools/pip_package:check_load_py_test" - ${BUILD_CMD} - cmd_status \ - "check_load_py_test failed to build." + cd "$ROOT_DIR/tensorflow/tools/pip_package" + python check_load_py_test.py +} - BUILD_CMD="bazel-bin/tensorflow/tools/pip_package/check_load_py_test" - ${BUILD_CMD} - cmd_status \ - "check_load_py_test failed." +do_cmake_python_sanity() { + cd "$ROOT_DIR/tensorflow/contrib/cmake" + python -m unittest -v python_sanity_test +} + +do_check_futures_test() { + cd "$ROOT_DIR/tensorflow/tools/test" + python check_futures_test.py } # Supply all sanity step commands and descriptions -SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check") -SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links") +SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity") +SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency") INCREMENTAL_FLAG="" DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp" @@ -548,7 +546,10 @@ while [[ ${COUNTER} -lt "${#SANITY_STEPS[@]}" ]]; do "${SANITY_STEPS[COUNTER]} (${SANITY_STEPS_DESC[COUNTER]}) ===" echo "" + # subshell: don't leak variables or changes of working directory + ( ${SANITY_STEPS[COUNTER]} ${INCREMENTAL_FLAG} + ) RESULT=$? if [[ ${RESULT} != "0" ]]; then diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh index e3e6b2f3166e18dc29ae24671889230d3a4a71c7..51e10f81f82da7920e9d219eaec3e1eb2973b998 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh @@ -26,12 +26,13 @@ echo "" # Run configure. export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx' # Only running cc tests, python version does not matter. export PYTHON_BIN_PATH=`which python` yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc,java -k \ - --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh index 5110d52f31c257a043177ede686817e6206fa2eb..ea14848b1ae74ef0c42d14678fde225d465512bf 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh @@ -26,11 +26,12 @@ echo "" # Run configure. export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx' export PYTHON_BIN_PATH=`which python2` yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ - --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh index df6016504cec19e02af988e87733fc409cef6826..6d017c8a1f0232deab82278b26797a73b3a8ea9c 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -26,12 +26,13 @@ echo "" # Run configure. export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx' export PYTHON_BIN_PATH=`which python3` yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ - --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \ --test_output=errors -- \ //tensorflow/contrib/... \ -//tensorflow/contrib/lite/... \ diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh index ea9e102936bc56288acea051af3d3414766d38fb..a9accb9dd5b2d23e028a34ac3d99976d5f2f59db 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh @@ -26,11 +26,12 @@ echo "" # Run configure. export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx' export PYTHON_BIN_PATH=`which python3` yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ - --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh index df196f829cd920b538fd0032950a9282c3043617..02224d8e9d9efd92b5c1658118bd0c45bdf4f1db 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh @@ -26,6 +26,7 @@ echo "" # Run configure. export PYTHON_BIN_PATH=`which python3` +export CC_OPT_FLAGS='-mavx' export TF_NEED_CUDA=1 export TF_CUDA_COMPUTE_CAPABILITIES=3.7 @@ -35,6 +36,6 @@ yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ - --build_tests_only --test_output=errors --local_test_jobs=8 \ + --build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh index abd256a895ea751f84ec946a85a4331fe5b23440..0367a53d1459e7207a76c83e0c1e5c83580722a7 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh @@ -26,6 +26,7 @@ echo "" # Run configure. export PYTHON_BIN_PATH=`which python3` +export CC_OPT_FLAGS='-mavx' export TF_NEED_CUDA=1 export TF_CUDA_COMPUTE_CAPABILITIES=3.7 @@ -35,6 +36,6 @@ yes "" | $PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ - --build_tests_only --test_output=errors --local_test_jobs=8 \ + --build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh index ddaaddc9179ab640ce5b09b4d8732944b8177f8a..509ee38ec4fd584037f8e43726c01391430c1817 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh @@ -27,11 +27,12 @@ echo "" # Run configure. export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx' export PYTHON_BIN_PATH=$(which python2) yes "" | $PYTHON_BIN_PATH configure.py which bazel bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ --test_timeout 300,450,1200,3600 \ - --test_size_filters=small,medium \ + --test_size_filters=small,medium --config=opt \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ //tensorflow/contrib/... -//tensorflow/contrib/lite/... diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh index e026dcd08f1e4ba88cd231fa33cb29ce3e916652..05547136704394ed9262f566a2bfb4160b73c7fd 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh @@ -27,11 +27,12 @@ echo "" # Run configure. export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx' export PYTHON_BIN_PATH=$(which python2) yes "" | $PYTHON_BIN_PATH configure.py which bazel bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ - --test_timeout 300,450,1200,3600 \ + --test_timeout 300,450,1200,3600 --config=opt \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py index d2a63e5d66a34f61d17e8327d4b25320371c4fa3..347d0769a92cc767f2e263fce0e21d7d0bc8e586 100755 --- a/tensorflow/tools/ci_build/update_version.py +++ b/tensorflow/tools/ci_build/update_version.py @@ -25,19 +25,19 @@ # pylint: disable=superfluous-parens import argparse -import fileinput import os import re import subprocess import time -# File parameters +# File parameters. TF_SRC_DIR = "tensorflow" VERSION_H = "%s/core/public/version.h" % TF_SRC_DIR SETUP_PY = "%s/tools/pip_package/setup.py" % TF_SRC_DIR README_MD = "./README.md" DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel" % TF_SRC_DIR GPU_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-gpu" % TF_SRC_DIR +CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-cpu-mkl" % TF_SRC_DIR RELEVANT_FILES = [TF_SRC_DIR, VERSION_H, SETUP_PY, @@ -45,17 +45,11 @@ RELEVANT_FILES = [TF_SRC_DIR, DEVEL_DOCKERFILE, GPU_DEVEL_DOCKERFILE] -# Version type parameters +# Version type parameters. NIGHTLY_VERSION = 1 REGULAR_VERSION = 0 -def replace_line(old_line, new_line, filename): - """Replace a line in a file.""" - for line in fileinput.input(filename, inplace=True): - print(line.rstrip().replace(old_line, new_line)) - - def check_existence(filename): """Check the existence of file or dir.""" if not os.path.exists(filename): @@ -69,9 +63,12 @@ def check_all_files(): check_existence(file_name) -def replace_with_sed(query, filename): +def replace_string_in_line(search, replace, filename): """Replace with sed when regex is required.""" - subprocess.check_call(['sed', '-i', '-r', '-e', query, filename]) + with open(filename, "r") as source: + content = source.read() + with open(filename, "w") as source: + source.write(re.sub(search, replace, content)) class Version(object): @@ -125,13 +122,13 @@ class Version(object): Raises: RuntimeError: If the version string is not valid. """ - # Check validity of new version string + # Check validity of new version string. if not re.search(r"[0-9]+\.[0-9]+\.[a-zA-Z0-9]+", string): raise RuntimeError("Invalid version string: %s" % string) major, minor, extension = string.split(".", 2) - # Isolate patch and identifier string if identifier string exists + # Isolate patch and identifier string if identifier string exists. extension_split = extension.split("-", 1) patch = extension_split[0] if len(extension_split) == 2: @@ -154,7 +151,7 @@ def get_current_semver_version(): core/public/version.h """ - # Get current version information + # Get current version information. version_file = open(VERSION_H, "r") for line in version_file: major_match = re.search("^#define TF_MAJOR_VERSION ([0-9]+)", line) @@ -185,32 +182,33 @@ def get_current_semver_version(): def update_version_h(old_version, new_version): """Update tensorflow/core/public/version.h.""" - replace_line("#define TF_MAJOR_VERSION %s" % old_version.major, - "#define TF_MAJOR_VERSION %s" % new_version.major, VERSION_H) - replace_line("#define TF_MINOR_VERSION %s" % old_version.minor, - "#define TF_MINOR_VERSION %s" % new_version.minor, VERSION_H) - replace_line("#define TF_PATCH_VERSION %s" % old_version.patch, - "#define TF_PATCH_VERSION %s" % new_version.patch, VERSION_H) - replace_line("#define TF_VERSION_SUFFIX \"%s\"" % - old_version.identifier_string, - "#define TF_VERSION_SUFFIX \"%s\"" - % new_version.identifier_string, - VERSION_H) + replace_string_in_line("#define TF_MAJOR_VERSION %s" % old_version.major, + "#define TF_MAJOR_VERSION %s" % new_version.major, + VERSION_H) + replace_string_in_line("#define TF_MINOR_VERSION %s" % old_version.minor, + "#define TF_MINOR_VERSION %s" % new_version.minor, + VERSION_H) + replace_string_in_line("#define TF_PATCH_VERSION %s" % old_version.patch, + "#define TF_PATCH_VERSION %s" % new_version.patch, + VERSION_H) + replace_string_in_line( + "#define TF_VERSION_SUFFIX \"%s\"" % old_version.identifier_string, + "#define TF_VERSION_SUFFIX \"%s\"" % new_version.identifier_string, + VERSION_H) def update_setup_dot_py(old_version, new_version): """Update setup.py.""" - replace_line("_VERSION = '%s'" % old_version.string, - "_VERSION = '%s'" % new_version.string, SETUP_PY) + replace_string_in_line("_VERSION = '%s'" % old_version.string, + "_VERSION = '%s'" % new_version.string, SETUP_PY) def update_readme(old_version, new_version): """Update README.""" pep_440_str = new_version.pep_440_str - replace_with_sed(r"s/%s\.%s\.([[:alnum:]]+)-/%s-/g" % (old_version.major, - old_version.minor, - pep_440_str), - README_MD) + replace_string_in_line(r"%s\.%s\.([[:alnum:]]+)-" % (old_version.major, + old_version.minor), + "%s-" % pep_440_str, README_MD) def update_md_files(old_version, new_version): @@ -226,22 +224,29 @@ def update_md_files(old_version, new_version): for filename in ["linux", "mac", "windows", "sources"]: filepath = "%s/docs_src/install/install_%s.md" % (TF_SRC_DIR, filename) - replace_with_sed("s/tensorflow-%s/tensorflow-%s/g" - % (old_pep_version, new_pep_version), filepath) - replace_with_sed("s/tensorflow_gpu-%s/tensorflow_gpu-%s/g" - % (old_pep_version, new_pep_version), filepath) - replace_with_sed("s/TensorFlow %s/TensorFlow %s/g" - % (old_pep_version, new_pep_version), filepath) + + if filename == "sources" and "rc0" in new_pep_version: + replace_string_in_line("(?)tensorflow-%s" % old_pep_version, + "tensorflow-%s" % new_pep_version, filepath) + replace_string_in_line("(?)tensorflow_gpu-%s" % old_pep_version, + "tensorflow_gpu-%s" % new_pep_version, filepath) + else: + replace_string_in_line("tensorflow-%s" % old_pep_version, + "tensorflow-%s" % new_pep_version, filepath) + replace_string_in_line("tensorflow_gpu-%s" % old_pep_version, + "tensorflow_gpu-%s" % new_pep_version, filepath) + replace_string_in_line("TensorFlow %s" % old_pep_version, + "TensorFlow %s" % new_pep_version, filepath) for filename in ["java", "go", "c"]: filepath = "%s/docs_src/install/install_%s.md" % (TF_SRC_DIR, filename) - replace_with_sed(r"s/x86_64-%s/x86_64-%s/g" - % (old_version, new_version), filepath) - replace_with_sed(r"s/libtensorflow-%s.jar/libtensorflow-%s.jar/g" - % (old_version, new_version), filepath) - replace_with_sed(r"s/%s<\/version>/%s<\/version>/g" - % (old_version, new_version), filepath) + replace_string_in_line(r"x86_64-%s" % old_version, + "x86_64-%s" % new_version, filepath) + replace_string_in_line(r"libtensorflow-%s.jar" % old_version, + "libtensorflow-%s.jar" % new_version, filepath) + replace_string_in_line(r"%s<\/version>" % old_version, + "%s" % new_version, filepath) def major_minor_change(old_version, new_version): @@ -266,10 +271,11 @@ def update_dockerfiles(old_version, new_version): % (old_r_major_minor_string, r_major_minor_string)) # Update dockerfiles - replace_with_sed("s/%s/%s/g" - % (old_r_major_minor, r_major_minor), DEVEL_DOCKERFILE) - replace_with_sed("s/%s/%s/g" - % (old_r_major_minor, r_major_minor), GPU_DEVEL_DOCKERFILE) + replace_string_in_line(old_r_major_minor, r_major_minor, DEVEL_DOCKERFILE) + replace_string_in_line(old_r_major_minor, r_major_minor, + GPU_DEVEL_DOCKERFILE) + replace_string_in_line(old_r_major_minor, r_major_minor, + CPU_MKL_DEVEL_DOCKERFILE) def check_for_lingering_string(lingering_string): @@ -333,7 +339,7 @@ def main(): old_version = get_current_semver_version() if args.nightly: - # dev minor version is one ahead of official + # Dev minor version is one ahead of official. nightly_minor_ver = int(old_version.minor) + 1 new_version = Version(old_version.major, str(nightly_minor_ver), @@ -349,12 +355,18 @@ def main(): update_md_files(old_version, new_version) update_dockerfiles(old_version, new_version) - # Print transition details + # Print transition details. print("Major: %s -> %s" % (old_version.major, new_version.major)) print("Minor: %s -> %s" % (old_version.minor, new_version.minor)) print("Patch: %s -> %s\n" % (old_version.patch, new_version.patch)) check_for_old_version(old_version, new_version) + if "rc0" in str(new_version): + print("\n\n\033[93mNOTE: Please update the tensorflow/docs_src/install/" + "install_sources.md and add a line for tensorflow-%s and " + "tensorflow_gpu-%s in the tested source configurations " + "table.\033[0m\n" % (new_version.pep_440_str, + new_version.pep_440_str)) if __name__ == "__main__": diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md index 39c040e051ec48ae6d1a1f6eb343a143930ba4f3..c1b1f79bbd4b657768b9bbcab93efa3354774915 100644 --- a/tensorflow/tools/dist_test/README.md +++ b/tensorflow/tools/dist_test/README.md @@ -17,7 +17,7 @@ cesnsu model: ./local_test.sh --model_name CENSUS_WIDENDEEP -**2) Launch a remote k8s cluster on Google Container Engine (GKE) and run the +**2) Launch a remote k8s cluster on Google Kubernetes Engine (GKE) and run the test suite on it** For example: diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh index fa867b65db50f4a197a35bc3aba98f9f6ecf4724..b4fba5b8f5e19c2fbb8c7261d8cf293757df503c 100755 --- a/tensorflow/tools/docker/parameterized_docker_build.sh +++ b/tensorflow/tools/docker/parameterized_docker_build.sh @@ -34,6 +34,11 @@ # If set to a non-empty string, will use it as the URL from which the # pip wheel file will be downloaded (instead of building the pip locally). # +# TF_DOCKER_BUILD_CENTRAL_PIP_IS_LOCAL +# (Optional) +# If set to a non-empty string, we will treat TF_DOCKER_BUILD_CENTRAL_PIP +# as a path rather than a url. +# # TF_DOCKER_BUILD_IMAGE_NAME: # (Optional) # If set to any non-empty value, will use it as the image of the @@ -234,6 +239,32 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then "COPY ${PIP_WHL} /\n"\ "RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \ > "${DOCKERFILE}" + + # Build from a local whl file path rather than an URL + elif [[ ! -z "${TF_DOCKER_BUILD_CENTRAL_PIP_IS_LOCAL}" ]]; then + PIP_WHL="${TF_DOCKER_BUILD_CENTRAL_PIP}" + if [[ -z "${PIP_WHL}" ]]; then + die "ERROR: Cannot locate the specified pip whl file" + fi + echo "Specified PIP whl file is at: ${PIP_WHL}" + + # Copy the pip file to tmp directory + cp "${PIP_WHL}" "${TMP_DIR}/" || \ + die "ERROR: Failed to copy wheel file: ${PIP_WHL}" + + # Use string replacement to put the correct file name into the Dockerfile + PIP_WHL=$(basename "${PIP_WHL}") + + # Modify the non-devel Dockerfile to point to the correct pip whl file + # location + sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ +"/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\ +"COPY ${PIP_WHL} /\n"\ +"RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \ + > "${DOCKERFILE}" + echo "Using local pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" + echo + else echo "Downloading pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" echo diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py index c033c16ae98c4bcaa4c0338e539324b3a2ae5552..b5df633800ae5a3ce67cf03910d472b9908d6249 100644 --- a/tensorflow/tools/docs/pretty_docs.py +++ b/tensorflow/tools/docs/pretty_docs.py @@ -323,7 +323,7 @@ class _Metadata(object): """ def __init__(self, name): - """Creata a Metadata builder. + """Create a Metadata builder. Args: name: The name of the page being described by the Metadata block. diff --git a/tensorflow/tools/graph_transforms/file_utils.h b/tensorflow/tools/graph_transforms/file_utils.h index 4737e95abcec3694d426e0c3c3a7112c2c5b6bd1..a3723f5cd383341ec206221e7591eca40aabd885 100644 --- a/tensorflow/tools/graph_transforms/file_utils.h +++ b/tensorflow/tools/graph_transforms/file_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ +#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ +#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -29,4 +29,4 @@ Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def); } // namespace graph_transforms } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ +#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index ff5dd6a0b09bcc0296d7add42d51fdd83b821c64..598080ed2753b862056ebcc76c4c572ae45b46e6 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -47,27 +47,6 @@ py_binary( deps = ["//tensorflow:tensorflow_py"], ) -py_test( - name = "pip_smoke_test", - srcs = ["pip_smoke_test.py"], - data = [ - "//tensorflow:all_opensource_files", - ], - tags = [ - "manual", - "notap", - ], -) - -py_binary( - name = "check_load_py_test", - srcs = ["check_load_py_test.py"], - data = [ - "//tensorflow:all_opensource_files", - ], - srcs_version = "PY2AND3", -) - # On Windows, python binary is a zip file of runfiles tree. # Add everything to its data dependency for generating a runfiles tree # for building the pip package on Windows. @@ -173,7 +152,8 @@ sh_binary( "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", "//tensorflow/contrib/py2tf:py2tf_internal", - "//tensorflow/contrib/py2tf/convert:convert", + "//tensorflow/contrib/py2tf/converters:converters", + "//tensorflow/contrib/py2tf/converters:test_lib", "//tensorflow/contrib/py2tf/pyct:pyct", "//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis", "//tensorflow/contrib/receptive_field:receptive_field_pip", diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index ca8c272a0894d1c8ab665d58bcf02bba4c300708..dc31e4c5f703b29f464519d5f1fd54f9b5e11690 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -137,8 +137,8 @@ function main() { fi fi fi - # Install toco as a binary in aux-bin. mkdir "${TMPDIR}/tensorflow/aux-bin" + # Install toco as a binary in aux-bin. cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/ fi diff --git a/tensorflow/tools/pip_package/check_load_py_test.py b/tensorflow/tools/pip_package/check_load_py_test.py index 79d11b08ce33d4509492927111309e647abe683b..e2fe1121d7fa3178ec60886c6dcb56fe374d38a5 100644 --- a/tensorflow/tools/pip_package/check_load_py_test.py +++ b/tensorflow/tools/pip_package/check_load_py_test.py @@ -22,6 +22,9 @@ import os import subprocess +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))) + + def check_output_despite_error(args): """Get output of args from command line, even if there are errors. diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index cddf9c8f44e3949d2e17dfd00b1a7a1dc4238d7e..38a900738786e2413f5b1dd914caaebeafc92e21 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """This pip smoke test verifies dependency files exist in the pip package. This script runs bazel queries to see what python files are required by the @@ -23,11 +22,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import subprocess -PIP_PACKAGE_QUERY_EXPRESSION = \ - 'deps(//tensorflow/tools/pip_package:build_pip_package)' +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) + +PIP_PACKAGE_QUERY_EXPRESSION = ( + "deps(//tensorflow/tools/pip_package:build_pip_package)") +# pylint: disable=g-backslash-continuation PY_TEST_QUERY_EXPRESSION = 'deps(\ filter("^((?!benchmark).)*$",\ kind(py_test,\ @@ -35,6 +38,7 @@ PY_TEST_QUERY_EXPRESSION = 'deps(\ + //tensorflow/contrib/... \ - //tensorflow/contrib/tensorboard/... \ - attr(tags, "manual|no_pip", //tensorflow/...))), 1)' +# pylint: enable=g-backslash-continuation # Hard-coded blacklist of files if not included in pip package # TODO(amitpatankar): Clean up blacklist. @@ -85,15 +89,15 @@ def main(): """ # pip_package_dependencies_list is the list of included files in pip packages - pip_package_dependencies = subprocess.check_output([ - 'bazel', 'query', PIP_PACKAGE_QUERY_EXPRESSION]) + pip_package_dependencies = subprocess.check_output( + ["bazel", "query", PIP_PACKAGE_QUERY_EXPRESSION]) pip_package_dependencies_list = pip_package_dependencies.strip().split("\n") print("Pip package superset size: %d" % len(pip_package_dependencies_list)) # tf_py_test_dependencies is the list of dependencies for all python # tests in tensorflow - tf_py_test_dependencies = subprocess.check_output([ - 'bazel', 'query', PY_TEST_QUERY_EXPRESSION]) + tf_py_test_dependencies = subprocess.check_output( + ["bazel", "query", PY_TEST_QUERY_EXPRESSION]) tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n") print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list)) @@ -114,8 +118,7 @@ def main(): # Check if the dependency is in the pip package, the blacklist, or # should be ignored because of its file extension - if not (ignore or - dependency in pip_package_dependencies_list or + if not (ignore or dependency in pip_package_dependencies_list or dependency in BLACKLIST): missing_dependencies.append(dependency) @@ -126,19 +129,20 @@ def main(): for missing_dependency in missing_dependencies: print("\nMissing dependency: %s " % missing_dependency) print("Affected Tests:") - rdep_query = 'rdeps(kind(py_test, \ - //tensorflow/python/...), %s)' % missing_dependency - affected_tests = subprocess.check_output(['bazel', 'query', rdep_query]) + rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" % + missing_dependency) + affected_tests = subprocess.check_output(["bazel", "query", rdep_query]) affected_tests_list = affected_tests.split("\n")[:-2] print("\n".join(affected_tests_list)) raise RuntimeError("""One or more dependencies are not in the pip package. Please either blacklist the dependencies in -tensorflow/tensorflow/tensorflow/tools/pip_package/pip_smoke_test.py -or add them to tensorflow/tensorflow/tensorflow/tools/pip_package/BUILD.""") +//tensorflow/tools/pip_package/pip_smoke_test.py +or add them to //tensorflow/tools/pip_package/BUILD.""") else: print("TEST PASSED") + if __name__ == "__main__": main() diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h index 44387bbd4d8cbedf3178ca799d75c758c054a10e..e18d749cff8864d5f900f07028b4bf7f5cb07b7a 100644 --- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h +++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ +#ifndef TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ +#define TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -50,4 +50,4 @@ ProtoTextFunctionCode GetProtoTextFunctionCode( } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ +#endif // TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD index 28d651e9106b29058824c06b160df2b9b5781757..159a8c1cfbdb793d05eda850afb54e860bf2614e 100644 --- a/tensorflow/tools/test/BUILD +++ b/tensorflow/tools/test/BUILD @@ -104,12 +104,3 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) - -py_test( - name = "check_futures_test", - size = "small", - srcs = ["check_futures_test.py"], - data = ["//tensorflow:all_opensource_files"], - srcs_version = "PY2AND3", - deps = ["@six_archive//:six"], -) diff --git a/tensorflow/tools/test/check_futures_test.py b/tensorflow/tools/test/check_futures_test.py index 1c07511888d6f641fd2d59a9e9161174e1ef1b5c..9181c9bd4a4497dbf22a1f0935795c65533f08d8 100644 --- a/tensorflow/tools/test/check_futures_test.py +++ b/tensorflow/tools/test/check_futures_test.py @@ -33,7 +33,7 @@ import re import six -BASE_DIR = os.path.normpath(os.path.join(__file__, '../../..')) +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) FUTURES_PATTERN = re.compile(r'^from __future__ import (\w+)\s*$') FUTURES_PATTERN_2 = re.compile( r'^from __future__ import (\w+), (\w+), (\w+)\s*$') diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index b27b1f21fbe0607e6f97050c530f6e0b6e3580f9..9145d9e58a3df6c074d5ac44a665a33339c45cc6 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -5,6 +5,7 @@ load("//third_party/mkl:build_defs.bzl", "mkl_repository") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") +load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure") load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") load("//third_party:repo.bzl", "tf_http_archive") load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") @@ -65,6 +66,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # files, in case the parsing of those build files depends on the bazel # version we require here. check_bazel_version_at_least("0.5.4") + clang6_configure(name="local_config_clang6") cuda_configure(name="local_config_cuda") git_configure(name="local_config_git") sycl_configure(name="local_config_sycl") @@ -473,11 +475,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bfe367d1e2a3c75b8694967a83c7f05885e8f184.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/bfe367d1e2a3c75b8694967a83c7f05885e8f184.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz", ], - sha256 = "916c82948687f6be82dbb7764f707abc319e6e4ebaef868f745bd5f44b0f281c", - strip_prefix = "llvm-bfe367d1e2a3c75b8694967a83c7f05885e8f184", + sha256 = "b5429ccf8d57273cb8489714f728c997cd720ec66fc2c0292422ab8f0e729ce0", + strip_prefix = "llvm-11a2ca6eea8a7fe240a14c0c35fd2017341279be", build_file = str(Label("//third_party/llvm:llvm.BUILD")), )