diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de4fded6ae6e66995aa9f1687a9d598017416f7a..16d7051343aa92f26905ae81c0ae5c1ce63210e3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -68,8 +68,8 @@ Include a license at the top of new files. * [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1) * [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1) * [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2) -* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2) -* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1) +* [HTML license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/tf-backend.html#L2) +* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/backend.ts#L1) Bazel BUILD files also need to include a license section, e.g., [BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61). diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 1a401997c649518766acb2ebb0dea1c128bd0ba4..2f3df7cda9cec29ed0c2266629022f0a22b37df9 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -4,7 +4,7 @@ https://stackoverflow.com/questions/tagged/tensorflow If you open a GitHub issue, here is our policy: -1. It must be a bug or a feature request. +1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead). 2. The form below must be filled out. 3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues). diff --git a/README.md b/README.md index 0c93813e584d4e41fe80d50e047069b2dad8311a..c754c3f0db088be8b638c8bc508e1dd765d960d8 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | |-----------------|---------------------|------------------|-------------------|---------------| -| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | +| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while diff --git a/RELEASE.md b/RELEASE.md index fdf10407fda21444f1d0ee6cf20650d2659b146f..0fad3b5d41cffd3e99e2b8900894bb02d94eb567 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,9 +1,98 @@ +# Release 1.6.0 + +## Breaking Changes +* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7. +* Prebuilt binaries will use AVX instructions. This may break TF on older CPUs. + +## Major Features And Improvements +* New Optimizer internal API for non-slot variables. Descendants of AdamOptimizer that access _beta[12]_power will need to be updated. +* `tf.estimator.{FinalExporter,LatestExporter}` now export stripped SavedModels. This improves forward compatibility of the SavedModel. +* FFT support added to XLA CPU/GPU. + +## Bug Fixes and Other Changes +* Documentation updates: + * Added a second version of Getting Started, which is aimed at ML +newcomers. + * Clarified documentation on `resize_images.align_corners` parameter. + * Additional documentation for TPUs. +* Google Cloud Storage (GCS): + * Add client-side throttle. + * Add a `FlushCaches()` method to the FileSystem interface, with an implementation for GcsFileSystem. +* Other: + * Add `tf.contrib.distributions.Kumaraswamy`. + * `RetryingFileSystem::FlushCaches()` calls the base FileSystem's `FlushCaches()`. + * Add auto_correlation to distributions. + * Add `tf.contrib.distributions.Autoregressive`. + * Add SeparableConv1D layer. + * Add convolutional Flipout layers. + * When both inputs of `tf.matmul` are bfloat16, it returns bfloat16, instead of float32. + * Added `tf.contrib.image.connected_components`. + * Add `tf.contrib.framework.CriticalSection` that allows atomic variable access. + * Output variance over trees predictions for classifications tasks. + * For `pt` and `eval` commands, allow writing tensor values to filesystem as numpy files. + * gRPC: Propagate truncated errors (instead of returning gRPC internal error). + * Augment parallel_interleave to support 2 kinds of prefetching. + * Improved XLA support for C64-related ops log, pow, atan2, tanh. + * Add probabilistic convolutional layers. + +## API Changes +* Introducing prepare_variance boolean with default setting to False for backward compatibility. +* Move `layers_dense_variational_impl.py` to `layers_dense_variational.py`. + +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Ag Ramesh, Aiden Scandella, Akimasa Kimura, Alex Rothberg, Allen Goodman, +amilioto, Andrei Costinescu, Andrei Nigmatulin, Anjum Sayed, Anthony Platanios, +Anush Elangovan, Armando Fandango, Ashish Kumar Ram, Ashwini Shukla, Ben, Bhavani Subramanian, +Brett Koonce, Carl Thomé, cclauss, Cesc, Changming Sun, Christoph Boeddeker, Clayne Robison, +Clemens Schulz, Clint (Woonhyuk Baek), codrut3, Cole Gerdemann, Colin Raffel, Daniel Trebbien, +Daniel Ylitalo, Daniel Zhang, Daniyar, Darjan Salaj, Dave Maclachlan, David Norman, Dong--Jian, +dongsamb, dssgsra, Edward H, eladweiss, elilienstein, Eric Lilienstein, error.d, Eunji Jeong, fanlu, +Florian Courtial, fo40225, Fred, Gregg Helt, Guozhong Zhuang, Hanchen Li, hsm207, hyunyoung2, +ImSheridan, Ishant Mrinal Haloi, Jacky Ko, Jay Young, Jean Flaherty, Jerome, JerrikEph, Jesse +Kinkead, jfaath, Jian Lin, jinghuangintel, Jiongyan Zhang, Joel Hestness, Joel Shor, Johnny Chan, +Julian Niedermeier, Julian Wolff, JxKing, K-W-W, Karl Lessard, Kasper Marstal, Keiji Ariyama, +Koan-Sin Tan, Loki Der Quaeler, Loo Rong Jie, Luke Schaefer, Lynn Jackson, ManHyuk, Matt Basta, +Matt Smith, Matthew Schulkind, Michael, michaelkhan3, Miguel Piedrafita, Mikalai Drabovich, +Mike Knapp, mjwen, mktozk, Mohamed Aly, Mohammad Ashraf Bhuiyan, Myungjoo Ham, Naman Bhalla, +Namrata-Ibm, Nathan Luehr, nathansilberman, Netzeband, Niranjan Hasabnis, Omar Aflak, Ozge +Yalcinkaya, Parth P Panchal, patrickzzy, Patryk Chrabaszcz, Paul Van Eck, Paweł Kapica, Peng Yu, +Philip Yang, Pierre Blondeau, Po-Hsien Chu, powderluv, Puyu Wang, Rajendra Arora, Rasmus, Renat +Idrisov, resec, Robin Richtsfeld, Ronald Eddy Jr, Sahil Singh, Sam Matzek, Sami Kama, sandipmgiri, +Santiago Castro, Sayed Hadi Hashemi, Scott Tseng, Sergii Khomenko, Shahid, Shengpeng Liu, Shreyash +Sharma, Shrinidhi Kl, Simone Cirillo, simsicon, Stanislav Levental, starsblinking, Stephen Lumenta, +Steven Hickson, Su Tang, Taehoon Lee, Takuya Wakisaka, Ted Chang, Ted Ying, Tijmen Verhulsdonck, +Timofey Kondrashov, vade, vaibhav, Valentin Khrulkov, vchigrin, Victor Costan, Viraj Navkal, +Vivek Rane, wagonhelm, Yan Facai (颜发才), Yanbo Liang, Yaroslav Bulatov, yegord, Yong Tang, +Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田传武 + # Release 1.5.0 ## Breaking Changes -* Prebuilt binaries are now built against CUDA 9 and cuDNN 7. -* Our Linux binaries are built using ubuntu 16 containers, potentially - introducing glibc incompatibility issues with ubuntu 14. +* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7. * Starting from 1.6 release, our prebuilt binaries will use AVX instructions. This may break TF on older CPUs. @@ -12,7 +101,7 @@ preview version is now available. * [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite) dev preview is now available. -* CUDA 9 and cuDNN 7 support. +* CUDA 9.0 and cuDNN 7 support. * Accelerated Linear Algebra (XLA): * Add `complex64` support to XLA compiler. * `bfloat` support is now added to XLA infrastructure. @@ -125,6 +214,27 @@ * Minor refactor: move stats files from `stochastic` to `common` and remove `stochastic`. +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. + ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9e69613c79ebd1d63ff052295cdb7acaaea5ff92..c225cc1a74ca34b818a9e9bed878c9a0d5b22cc0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -550,8 +550,10 @@ filegroup( "//tensorflow/contrib/predictor:all_files", "//tensorflow/contrib/py2tf:all_files", "//tensorflow/contrib/py2tf/converters:all_files", + "//tensorflow/contrib/py2tf/impl:all_files", "//tensorflow/contrib/py2tf/pyct:all_files", "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", + "//tensorflow/contrib/py2tf/utils:all_files", "//tensorflow/contrib/quantize:all_files", "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 3c7f041b39f01d9b8b187079b00e0c5ad99a38cc..e7fb1dec53b56544679173e79729a9d3b950fc89 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -195,10 +195,10 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, reinterpret_cast(data) % EIGEN_MAX_ALIGN_BYTES != 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste - // (any alignement requirements will be taken care of by TF_TensorToTensor + // (any alignment requirements will be taken care of by TF_TensorToTensor // and TF_TensorFromTensor). // - // Other types have the same represntation, so copy only if it is safe to do + // Other types have the same representation, so copy only if it is safe to do // so. buf->data_ = allocate_tensor("TF_NewTensor", len); std::memcpy(buf->data_, data, len); @@ -2144,7 +2144,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, opts.return_tensors.push_back(ToTensorId(nodes_to_return[i])); } - // TOOD(skyewm): change to OutputTensor + // TODO(skyewm): change to OutputTensor tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 74190cb135ac6c17bfcc9d8bd2f7c75ac5e8c076..e62310d811462f88af93505393b622d9a87c72d3 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -46,6 +46,7 @@ tf_cuda_library( "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", "//tensorflow/core:lib_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index a76c8f5ec05fc3199addc67857d7bb2ea0e263c2..d65b592895950cea3b528478e5bd6257ac688cc6 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -85,15 +85,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { return nullptr; } - TFE_Context* ret = new TFE_Context(session); - ret->policy = opts->policy; - ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime( - ret->session->device_mgr, opts->session_options.options.env, - TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {})); - ret->rendezvous = - new tensorflow::IntraProcessRendezvous(ret->session->device_mgr); - - return ret; + return new TFE_Context(*opts, session); } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { @@ -261,15 +253,6 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, void TFE_DeleteOp(TFE_Op* op) { delete op; } -static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device, - TF_Status* status) { - // Questionable heuristic: Place the op on the same device as the first input - // placed outside of host memory? - if (IsCPU(op->device) && !IsCPU(device)) { - op->device = device; - } -} - void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { @@ -277,11 +260,24 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { op->ctx->session->device_mgr->LookupDevice(device_name, &d); if (!status->status.ok()) return; } - TFE_OpSetDeviceHelper(op, d, status); + op->device = d; +} + +const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { + tensorflow::Device* device = + (op->device == nullptr) ? op->ctx->devices()[0] : op->device; + return device->name().c_str(); } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - TFE_OpSetDeviceHelper(op, h->d, status); + // Questionable heuristic ... + // + // Motivation: After an 'op' is placed on GPU because some of its earlier + // inputs are on GPU, we want to keep the 'op' there, even if some later + // inputs of it are not on GPU. + if (IsCPU(op->device) && !IsCPU(h->d)) { + op->device = h->d; + } if (!status->status.ok()) return; op->inputs.push_back(h->t); op->input_devices.push_back(h->d); @@ -298,7 +294,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return TF_ATTR_INT; // The compiler requires that we return something. } status->status = - tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list); + tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list); return ret; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 387de078948e5076d0b069d6380dfc04ea6254df..6a2aff1591d551d4859ef9686604680923802101 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -154,6 +154,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status); +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, + TF_Status* status); TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a6f76c732f2a4c2402a27cd69c101d028dbb8fcc..f2abffb7bc042ffb6d0c5f360fe96588c61ff176 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" struct TFE_ContextOptions { TF_SessionOptions session_options; @@ -43,9 +44,15 @@ struct TFE_ContextOptions { }; struct TFE_Context { - explicit TFE_Context(TF_Session* s) : session(s) {} + explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s) + : policy(opts.policy), + session(s), + rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), + pflr(new tensorflow::ProcessFunctionLibraryRuntime( + session->device_mgr, opts.session_options.options.env, + TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} - TFE_ContextDevicePlacementPolicy policy; + const TFE_ContextDevicePlacementPolicy policy; // Note: we cannot use C++11 thread_local here as there is no concept of a // thread-local-object-local variable in C++11. @@ -54,8 +61,8 @@ struct TFE_Context { thread_local_policies GUARDED_BY(policy_map_mu); // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* session; - tensorflow::Rendezvous* rendezvous; + TF_Session* const session; + tensorflow::Rendezvous* const rendezvous; tensorflow::mutex functions_mu; tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ @@ -64,14 +71,14 @@ struct TFE_Context { // One FunctionLibraryRuntime per device. // func_libs[i] is the FunctionLibraryRuntime corresponding to // session->devices[i]. - std::unique_ptr pflr; + const std::unique_ptr pflr; tensorflow::mutex cache_mu; std::unordered_map kernel_cache GUARDED_BY(cache_mu); - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { + tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { return pflr->GetFLR(d->name()); } @@ -100,6 +107,8 @@ struct TFE_TensorHandle { }; struct TFE_Op { + // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a + // primitive operation. TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 18e7a64435e6c7e51998a744abd615edc7ad4318..b0409af87c25b7eb106a69260f9f0b4c30317490 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -60,6 +60,31 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { return op; } +// If there is a GPU device, returns true and sets 'gpu_device_name' +// accordingly. +bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string device_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string device_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (device_type == "GPU") { + *gpu_device_name = device_name; + LOG(INFO) << "Found GPU device " << device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} + void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -288,22 +313,15 @@ TEST(CAPI, TensorHandleSilentCopy) { TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - const int num_devices = TF_DeviceListCount(devices); - // Disable the test if no GPU is present. - if (num_devices > 1) { - const int device_to_use = 1; - const string name(TF_DeviceListName(devices, device_to_use, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - TFE_TensorHandle* hgpu = - TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); - TFE_OpSetDevice(matmul, name.c_str(), status.get()); + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* retvals[1]; int num_retvals = 1; @@ -314,7 +332,6 @@ TEST(CAPI, TensorHandleSilentCopy) { TFE_DeleteTensorHandle(hgpu); } - TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); TFE_DeleteContext(ctx, status.get()); @@ -337,22 +354,15 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - const int num_devices = TF_DeviceListCount(devices); - // Disable the test if no GPU is present. - if (num_devices > 1) { - const int device_to_use = 1; - const string name(TF_DeviceListName(devices, device_to_use, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - TFE_TensorHandle* hgpu = - TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); - TFE_OpSetDevice(matmul, name.c_str(), status.get()); + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* retvals[1]; int num_retvals = 1; @@ -363,13 +373,43 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TFE_DeleteTensorHandle(hgpu); } - TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, SetAndGetOpDevices) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + + // Disable the test if no GPU is present. + string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_OpSetDevice(matmul, "GPU:0", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + const char* device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr); + + TFE_OpSetDevice(matmul, "CPU:0", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr); + } + + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + TEST(CAPI, Execute) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index 3a9951e14de3a70e0b9e47fa62e6342e063c4bed..12abfcba2f02782172ab8b978487461770f44167 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { return Status::OK(); } -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list) { - CHECK(m); - auto* t = gtl::FindOrNull(*m, attr_name); + auto* t = gtl::FindOrNull(m, attr_name); if (t == nullptr) { return errors::InvalidArgument("Attribute '", attr_name, "' does not exist for this operation"); @@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a, b->high64 += a.high64; } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, const tensorflow::Fprint128& b) { // TODO(agarwal): avoid ToString(). tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString()); return FingerprintCat128(a, b); } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) { +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) { return CacheKeyHelper(s, {b, b}); } diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index e28a416e67f8382dbd490648106a7eb6e5fcfd13..4d20b5244a46fcde2eed0a429dced2a77b86aedd 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -43,7 +43,7 @@ typedef std::unordered_map AttrTypeMap; Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list); // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 2ccca66f672b96b3c782ddbfc828eeda270cebee..643153058ce3d6f0c88dd23a0dec4c6eff060319 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) { TF_AttrType t; unsigned char is_list = 1; - s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); + s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); EXPECT_FALSE(s.ok()); EXPECT_NE(is_list, 0); - s = AttrTypeByName(m, "transpose_a", &t, &is_list); + s = AttrTypeByName(*m, "transpose_a", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_BOOL, t); EXPECT_EQ(is_list, 0); s = AttrTypeMapForOp("Squeeze", &m); ASSERT_TRUE(s.ok()) << s; - s = AttrTypeByName(m, "squeeze_dims", &t, &is_list); + s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_INT, t); EXPECT_NE(is_list, 0); diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index acef098c7d07f45d171679bff7c41e13ef0424f1..faa1e378d07ea94ad08ee084d18bf6a113f054af 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { - session->reset(NewSession(session_options)); + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); return (*session)->Create(meta_graph_def.graph_def()); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 0ad6b33bba5fcceaca68e2f179cef2232c689a80..4c64d2cfe3c10e6c7ed82a2d72460a0b34283bb2 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { << st.error_message(); } +TEST_F(LoaderTest, SessionCreationFailure) { + SavedModelBundle bundle; + // Use invalid SessionOptions to cause session creation to fail. Default + // options work, so provide an invalid value for the target field. + SessionOptions session_options; + constexpr char kInvalidTarget[] = "invalid target"; + session_options.target = kInvalidTarget; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + << st.error_message(); +} + TEST_F(LoaderTest, PbtxtFormat) { SavedModelBundle bundle; SessionOptions session_options; diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 2b9c83ba149adf9e089786b91039e256216579c8..58572fea3db5599cc282944e15c866dcf5f25de0 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -4,7 +4,7 @@ To use from your BUILD file, add the following line to load the macro: -load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") Then call the macro like this: @@ -16,14 +16,15 @@ tf_library( ) """ -load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts") +load("//tensorflow:tensorflow.bzl", + "if_android", "tf_cc_test", "tf_copts") def tf_library(name, graph, config, freeze_checkpoint=None, freeze_saver=None, cpp_class=None, gen_test=True, gen_benchmark=True, visibility=None, testonly=None, tfcompile_flags=None, - tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile", + tfcompile_tool="//tensorflow/compiler/aot:tfcompile", include_standard_runtime_deps=True, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. @@ -119,9 +120,9 @@ def tf_library(name, graph, config, out_nodes_file, ] + freeze_saver_srcs, outs=[freeze_file], - cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" + + cmd=("$(location //tensorflow/python/tools:freeze_graph)" + freeze_args), - tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"], + tools=["//tensorflow/python/tools:freeze_graph"], tags=tags, ) tfcompile_graph = freeze_file @@ -213,22 +214,22 @@ def tf_library(name, graph, config, # These deps are required by all tf_library targets even if # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", - "@org_tensorflow//tensorflow/core:framework_lite", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/core:framework_lite", ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the proto. - "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_data_proto", ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", ] or []) + (deps or []), tags=tags, @@ -254,28 +255,32 @@ def tf_library(name, graph, config, name=("gen_" + test_name), testonly=1, srcs=[ - "@org_tensorflow//tensorflow/compiler/aot:test.cc", + "//tensorflow/compiler/aot:test.cc", header_file, ], outs=[test_file], cmd=("sed " + sed_replace + - " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " + + " $(location //tensorflow/compiler/aot:test.cc) " + "> $(OUTS)"), tags=tags, ) - # The cc_test rule for the generated code. - native.cc_test( + # The cc_test rule for the generated code. To ensure that this works + # reliably across build configurations, we must use tf_cc_test instead of + # native.cc_test. This is related to how we build + # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD + # for more details. + tf_cc_test( name=test_name, srcs=[test_file], deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/aot:runtime", - "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", - "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", - "@org_tensorflow//tensorflow/core:lib", - "@org_tensorflow//tensorflow/core:test", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], tags=tags, ) @@ -283,7 +288,7 @@ def tf_library(name, graph, config, if gen_benchmark: benchmark_name = name + "_benchmark" benchmark_file = benchmark_name + ".cc" - benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" + + benchmark_main = ("//tensorflow/compiler/aot:" + "benchmark_main.template") # Rule to rewrite benchmark.cc to produce the benchmark_file. @@ -301,7 +306,9 @@ def tf_library(name, graph, config, tags=tags, ) - # The cc_benchmark rule for the generated code. + # The cc_benchmark rule for the generated code. This does not need the + # tf_cc_binary since we (by deliberate design) do not depend on + # //tensorflow/core:lib. # # Note: to get smaller size on android for comparison, compile with: # --copt=-fvisibility=hidden @@ -315,12 +322,12 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/aot:benchmark", - "@org_tensorflow//tensorflow/compiler/aot:runtime", - "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", ] + if_android([ - "@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android", + "//tensorflow/compiler/aot:benchmark_extra_android", ]), tags=tags, ) @@ -330,11 +337,11 @@ def target_llvm_triple(): # TODO(toddw): Add target_triple for other targets. For details see: # http://llvm.org/docs/doxygen/html/Triple_8h_source.html return select({ - "@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android", - "@org_tensorflow//tensorflow:android_arm": "armv7-none-android", - "@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android", - "@org_tensorflow//tensorflow:android_x86": "i686-none-android", - "@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin", + "//tensorflow:android_armeabi": "armv5-none-android", + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 0de163d3a8f082eab4d8d802485da1bbc56e8180..9c372a012789fc25ca0a711349c09ca62edc6754 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -30,12 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -141,8 +143,7 @@ struct NodeSlot { // everything to use it. static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; -static const char* const kSendToHostOp = "_XlaSendToHost"; -static const char* const kRecvFromHostOp = "_XlaRecvFromHost"; +static const char* const kHostComputeOp = "_XlaHostCompute"; static const char* const kSendFromHostOp = "_XlaSendFromHost"; static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; @@ -171,7 +172,8 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out); + Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -201,21 +203,29 @@ class Encapsulator { // .. . // RAH --> C --> SFH // - // The compiled cluster is as follows. STH is a SendToHost node which is the - // source of a channel to the RAH node above. RFH is a RecvFromHost node which - // is the destination of a channel from the SFH node above. There is a control - // edge that ensures RFH follows STH, which is used in shape inference to - // ensure that the shapes on the STH host channel are known before the RFH - // channel is compiled. + // The compiled cluster is as follows. HC is a HostCompute node which is the + // source of a channel to the RAH node above and the destination of a channel + // from the SFH node above. // - // Arg --> B --> STH ..> RFH --> D --> Retval + // Arg --> B --> HC --> D --> Retval // - // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most - // one RAH and SFH in each compiled cluster. This design is preferred over - // adding separate Arg/Retval nodes for each transmitted value because it - // simplifies the host code that would like to limit communication between - // host and device and, e.g., raise only one interrupt per channel rather than - // one per transmitted value. + // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is + // at most one RAH and SFH in each outside_compilation cluster. This design is + // preferred over adding separate Arg/Retval nodes for each transmitted value + // because it allows optimizations to the host code that would like to limit + // communication between host and device and, e.g., raise only one interrupt + // per channel rather than one per transmitted value. + // + // The shapes of the outputs from the HC node in general cannot be determined + // until the shapes of its inputs are known at compile time, since e.g., + // above, the shape of C's outputs aren't known until the shape of its inputs + // are known. If the shapes of the HC's outputs can be determined during the + // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal + // graph is stored in the shape_inference_graph attr. This graph can be used + // when compiling the HC Op to determined the shape of the SFH inputs given + // the shapes of any ancestor RAH outputs. If it can be determined that the + // shape of the SFH inputs will not be inferrable even once the shapes of the + // RAH outputs are known, an error is returned by the rewriter. class Subgraph { public: // Creates a graph to build the subgraph in, if it doesn't already exist, @@ -246,6 +256,10 @@ class Encapsulator { const std::unordered_map& node_images, Graph* graph_out); + // Returns the names of all the outside_compilation subgraphs in this + // Subgraph. + void GetOutsideCompilationSubgraphNames(std::vector* names) const; + // Returns the Node that inputs to the function should be wired up to. Node* GetCallNodeForInputs() const; @@ -305,15 +319,9 @@ class Encapsulator { void RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge); - // Adds the SendToHost nodes for each outside_compilation subgraph once the - // edges have all been recorded via RecordOutsideCompilationInputOrControl. - Status AddSendsToOutsideCompilation( - const std::unordered_map& node_images); - - // Adds the RecvFromHost nodes for each outside_compilation subgraph once - // the edges have all been recorded via - // RecordOutsideCompilationOutputOrControl. - Status AddRecvsFromOutsideCompilation( + // Adds the HostCompute nodes for each outside_compilation subgraph. + Status AddHostComputes( + const string& subgraph_name, const std::unordered_map& node_images); // Creates the sequencer node if it doesn't exist, adding it to graph_out. @@ -323,10 +331,16 @@ class Encapsulator { // all the downstream nodes of call_node_outputs. void ConnectSequencerToOutputs(Graph* graph_out); + Status AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph); + + Status ReplaceFunctionDef(FunctionLibraryDefinition* library); + private: struct OutsideCompilationSubgraph { // Map from source (producer node/slot) tensors in the original graph to - // input index (slot number in the SendToHost/RecvAtHost nodes that will + // input index (slot number in the HostCompute/RecvAtHost nodes that will // be created) for the outside_compilation subgraph. std::unordered_map inputs; @@ -335,14 +349,14 @@ class Encapsulator { // outside_compilation subgraph. These are recorded by // RecordOutsideCompilationInputOrControl while walking all the subgraph // edges, and lifted control edges within the subgraph are added by - // AddSendsToOutsideCompilation once the _SendToHost node has been + // AddSendsToOutsideCompilation once the _HostCompute node has been // created. The matching control edge from _RecvAtHost to the // destination is added by CopyEdgeToOutputGraph. std::unordered_set control_inputs; // Maps from source (producer node/slot) and destination (consumer // node/slot) tensors in the original graph to output index (slot number - // in the SendFromHost/RecvFromHost nodes that will be created) for the + // in the SendFromHost/HostCompute nodes that will be created) for the // outside_compilation subgraph. std::unordered_map outputs_by_src; std::unordered_map outputs_by_dst; @@ -352,13 +366,13 @@ class Encapsulator { // containing compiled subgraph. These are recorded by // RecordOutsideCompilationOutputOrControl while walking all the subgraph // edges, and lifted control edges within the subgraph are added by - // AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been + // AddRecvsFromToOutsideCompilation once the _HostCompute node has been // created. The matching control edge from the source to _SendFromHost to // the destination is added by CopyEdgeToOutputGraph. std::unordered_set control_outputs; - // _SendToHost node in the subgraph. Not owned. - Node* send_to_host = nullptr; + // Name of the _HostCompute node in the subgraph. + string host_compute_name; // _RecvAtHost node in the output graph. Not owned. Node* recv_at_host = nullptr; @@ -516,6 +530,59 @@ class Encapsulator { const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out); + // Constructs a minimal shape inference graph that can be used to determine + // the shape of send_node at the time that the subgraph is compiled. + // recv_at_host_nodes contains the names of all the recv_at_host nodes that + // send_node might depend on. These recv_at_host nodes have shapes that are + // not known during the rewrite pass, but will be known at compile time. + // + // If the shapes of all the inputs to send_node can be determined during the + // rewrite pass, on exit graphdef_out is empty and the shapes are returned in + // static_shape_out. Otherwise graphdef_out contains a graph that can be used + // for shape inference at compile time, where all the source nodes of the + // graph are either constants with known shapes, or nodes named in + // recv_at_host_nodes. + // + // A non-OK status is returned if neither of the above conditions can be + // satisfied, e.g., because send_node depends on a node that doesn't have a + // registered shape inference function. + Status DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out); + + // Makes a copy of graph containing only nodes that are ancestors of at least + // one node in send_from_host_nodes and store it in pruned_graph. On exit + // nodes_images contains a mapping from nodes in graph to nodes in + // pruned_graph. All functions in the copied graph are inlined. + Status MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Makes a copy of graph containing only nodes that are ancestors of a + // send_from_host node in an outside_compilation subgraph, and store it in + // pruned_graph. Also perform shape inference on the pruned graph, using + // shape_refiner. On exit node_images contains a mapping from nodes in graph + // to nodes in pruned_graph. + Status MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Performs static shape inference, as far as possible, for the send_from_host + // nodes in each outside_compilation subgraph. Where it is not possible to + // determine the shape statically, stores a serialized GraphDef in the + // HostCompute 'shape_inference_graph' attr, to be used at compile time for + // final inference. If the shapes are known statically they are stored in the + // HostCompute 'shapes' attr. + Status GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library); + const string group_attribute_; const string outside_compilation_attribute_; const Graph* graph_in_; @@ -682,16 +749,20 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( } } -Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( +Status Encapsulator::Subgraph::AddHostComputes( + const string& subgraph_name, const std::unordered_map& node_images) { for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { const string& oc_subgraph_name = oc_subgraph_iter.first; OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; - if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { - // Build a _SendToHost node sending all the args of the appropriate - // types. - std::vector dtypes(oc_subgraph.inputs.size(), DT_INVALID); + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() || + !oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + // Build a _HostCompute node. std::vector inputs(oc_subgraph.inputs.size()); + std::vector input_dtypes(oc_subgraph.inputs.size(), DT_INVALID); + std::vector output_dtypes(oc_subgraph.outputs_by_src.size(), + DT_INVALID); for (const auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; @@ -700,94 +771,64 @@ Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( int input_index = input_src.second; DataType dtype = src_node->output_type(src_slot); - dtypes[input_index] = dtype; inputs[input_index].Reset(src_image->name(), src_slot, dtype); + input_dtypes[input_index] = dtype; } - NodeDef send_def; - NodeDefBuilder builder( - strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"), - kSendToHostOp); - builder.Attr("dtypes", dtypes); + for (const auto& output : oc_subgraph.outputs_by_src) { + DataType dtype = output.first.dtype; + int output_index = output.second; + output_dtypes[output_index] = dtype; + } + + NodeDef host_compute_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), + kHostComputeOp); builder.Input(inputs); - Status s = builder.Finalize(&send_def); + builder.Attr("Tinputs", input_dtypes); + builder.Attr("Toutputs", output_dtypes); + builder.Attr("key", + strings::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); + Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; - oc_subgraph.send_to_host = graph_->AddNode(send_def, &s); + Node* host_compute = graph_->AddNode(host_compute_def, &s); if (!s.ok()) return s; + oc_subgraph.host_compute_name = host_compute->name(); - // Connect the _SendToHost node to its producers in the subgraph. + // Connect the _HostCompute node to its producers in the subgraph. for (auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); int src_slot = input_src.first.slot; int input_index = input_src.second; - graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host, - input_index); + graph_->AddEdge(src_image, src_slot, host_compute, input_index); } - // Connect the _SendToHost node to its control edge producers in the + // Connect the _HostCompute node to its control edge producers in the // subgraph. for (const auto& src_node : oc_subgraph.control_inputs) { Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, oc_subgraph.send_to_host); - } - } - } - - return Status::OK(); -} - -Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation( - const std::unordered_map& node_images) { - for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { - const string& oc_subgraph_name = oc_subgraph_iter.first; - OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; - if (!oc_subgraph.outputs_by_src.empty() || - !oc_subgraph.control_outputs.empty()) { - // Build a _RecvFromHost node producing all the outputs of the appropriate - // types. - std::vector dtypes(oc_subgraph.outputs_by_src.size(), - DT_INVALID); - - for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.first.dtype; - int output_index = output.second; - dtypes[output_index] = dtype; + graph_->AddControlEdge(src_image, host_compute); } - NodeDef recv_def; - NodeDefBuilder builder( - strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"), - kRecvFromHostOp); - builder.Attr("dtypes", dtypes); - Status s = builder.Finalize(&recv_def); - if (!s.ok()) return s; - - Node* recv = graph_->AddNode(recv_def, &s); - if (!s.ok()) return s; - - // Connect the consumers in the subgraph to the _RecvFromHost node. + // Connect the consumers in the subgraph to the _HostCompute node. for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; Node* dst_image = node_images.at(dst_node); int dst_slot = output.first.slot; int output_index = output.second; - graph_->AddEdge(recv, output_index, dst_image, dst_slot); + graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); } - // Connect the control edge consumers in the subgraph to the _RecvFromHost + // Connect the control edge consumers in the subgraph to the _HostCompute // node. for (const auto& dst_node : oc_subgraph.control_outputs) { Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(recv, dst_image); - } - - // Add a control edge in the subgraph so that the _SendToHost node, if - // any, is compiled before the _RecvFromHost node. - if (oc_subgraph.send_to_host != nullptr) { - graph_->AddControlEdge(oc_subgraph.send_to_host, recv); + graph_->AddControlEdge(host_compute, dst_image); } } } @@ -882,6 +923,63 @@ Status Encapsulator::Subgraph::BuildFunctionDef( return Status::OK(); } +Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph) { + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); + + Node* host_compute = nullptr; + for (Node* n : graph_->nodes()) { + if (n->name() == oc_subgraph.host_compute_name) { + host_compute = n; + break; + } + } + if (host_compute == nullptr) { + return errors::InvalidArgument( + "After rewriting subgraph ", outside_compilation_subgraph_name, + " there is no HostCompute Op for outside compilation subgraph ", + oc_subgraph.host_compute_name); + } + + if (inference_graph == nullptr) { + host_compute->AddAttr("shape_inference_graph", ""); + host_compute->AddAttr("shapes", shapes); + } else { + string serialized_graph; + if (!inference_graph->SerializeToString(&serialized_graph)) { + return errors::Internal( + "Failed to serialize graph for outside compilation subgraph ", + oc_subgraph.host_compute_name); + } + host_compute->AddAttr("shape_inference_graph", serialized_graph); + host_compute->AddAttr("shapes", std::vector()); + } + return Status::OK(); +} + +Status Encapsulator::Subgraph::ReplaceFunctionDef( + FunctionLibraryDefinition* library) { + const string& name = call_node_def_.name(); + + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); + + if (VLOG_IS_ON(1)) { + VLOG(2) << "Replace function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("replace_encapsulate_fdef_", name), fdef); + } + + TF_RETURN_IF_ERROR(library->RemoveFunction(name)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + return Status::OK(); +} + Status Encapsulator::Subgraph::BuildParallelCheckOp( const std::unordered_map& node_images, Graph* graph_out) { @@ -980,7 +1078,9 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); - builder.Attr("dtypes", dtypes); + builder.Attr("Toutputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); Status s = builder.Finalize(&recv_def); if (!s.ok()) return s; @@ -1020,7 +1120,9 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); - builder.Attr("dtypes", dtypes); + builder.Attr("Tinputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Input(inputs); Status s = builder.Finalize(&send_def); if (!s.ok()) return s; @@ -1062,6 +1164,13 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( return Status::OK(); } +void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( + std::vector* names) const { + for (auto& entry : outside_compilation_subgraphs_) { + names->push_back(entry.first); + } +} + Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); @@ -1220,8 +1329,7 @@ Status Encapsulator::SplitIntoSubgraphs() { // single input and output node for it. for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; - TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images)); - TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images)); + TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images)); } MarkGuaranteedConstants(*graph_in_, src_arg_pairs); @@ -1509,8 +1617,346 @@ Status Encapsulator::AddEdgesToOutputGraph( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, - Graph* graph_out) { +namespace { + +// Adds a dummy Const node to graph_out. The "constant" has the type of +// data_type and the shape indicated in 'shape'. The dummy node is not a valid +// Const node because it does not have any value defined, but this doesn't +// matter because it will only be used subsequently for shape inference. (It +// would be possible to add a switch statement over data_type to create a value +// for the constant, but that would entail maintaining the logic as new types +// are added, and is not necessary.) +Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, + Graph* graph_out) { + TensorProto dummy_proto; + dummy_proto.set_dtype(data_type); + *dummy_proto.mutable_tensor_shape() = shape; + // Don't set any value field in the proto, since it is only going to be used + // for shape inference. + + GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); + NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", + options.op_registry()); + node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); + return options.FinalizeBuilder(&node_builder); +} + +// Adds a copy of node_in to graph_out and adds the mapping to +// copied_node_images. +Status CopyShapeInferenceNodeToGraph( + Node* node_in, const Node* send_node, + const std::unordered_map& dummy_node_images, + FunctionLibraryDefinition* library, + std::unordered_map* copied_node_images, Graph* graph_out) { + // Once all the ancestor nodes have been added to graph_out, add this node + // and connect it to its ancestors. + Node* node_out = graph_out->CopyNode(node_in); + (*copied_node_images)[node_in] = node_out; + // Don't bother to build the shape inference graph if there's a node with no + // shape inference function, since it would just result in an error later at + // compile time. + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data)); + if (op_reg_data->shape_inference_fn == nullptr) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because it depends on node ", node_in->name(), + " which does not have a shape inference function registered."); + } + // Add all the edges to the newly copied node. + for (const Edge* in_edge : node_in->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src = in_edge->src(); + const auto iter = dummy_node_images.find(src); + if (iter == dummy_node_images.end()) { + // The src is a copied node so use the original output port. + graph_out->AddEdge((*copied_node_images)[in_edge->src()], + in_edge->src_output(), node_out, + in_edge->dst_input()); + } else { + // The src is a dummy node so use output port 0. + graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input()); + } + } + } + return Status::OK(); +} + +} // namespace + +Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out) { + // Maps from nodes in graph_in to nodes in graph_out. + // + // When an edge has fully defined shape the source node in graph_in is + // replaced in graph_out by a dummy constant node. The mapping from nodes + // in graph_in to dummy nodes is stored in dummy_node_images. + // + // When a node in graph_in has at least one ancestor that doesn't have fully + // defined shape, it is copied into graph_out. The mapping from nodes in + // graph_in to copied nodes is stored in copied_node_images. + // + // The two types of node are treated differently because, when adding edges to + // graph_out, an output from a dummy node always uses port 0, whereas an + // output from a copied node uses the same port that was used in graph_in. + std::unordered_map dummy_node_images; + std::unordered_map copied_node_images; + + std::unique_ptr graph_out(new Graph(graph_in.op_registry())); + graph_out->set_versions(graph_in.versions()); + static_shape_out->resize(send_node->num_inputs()); + + // We don't use the standard ReverseDFS because we want to cut off traversal + // whenever we find an output with fully defined shape. + // TODO(misard) make this work properly in the presence of control flow. + struct Work { + Node* node; + bool leave; // Are we entering or leaving node? + }; + std::vector stack({{send_node, false}}); + std::vector visited(graph_in.num_node_ids(), false); + while (!stack.empty()) { + Work w = stack.back(); + stack.pop_back(); + Node* n = w.node; + + if (w.leave) { + TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( + n, send_node, dummy_node_images, library, &copied_node_images, + graph_out.get())); + } else { + if (visited[n->id()]) continue; + visited[n->id()] = true; + + // Arrange to revisit when all done with all inputs. + stack.push_back(Work{n, true}); + + bool has_parent_with_unknown_shape = false; + for (const Edge* in_edge : n->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src_node = in_edge->src(); + int src_port = in_edge->src_output(); + shape_inference::InferenceContext* context = + shape_refiner.GetContext(src_node); + shape_inference::ShapeHandle shape = context->output(src_port); + if (context->FullyDefined(shape)) { + // This ancestor has known shape, so instead of adding it to the + // stack, add a dummy node with that shape to graph_out and + // continue. + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + dummy_node_images[src_node] = AddDummyShapedNode( + src_node->output_type(src_port), proto, graph_out.get()); + if (n == send_node) { + (*static_shape_out)[in_edge->dst_input()] = proto; + } + } else { + if (!visited[src_node->id()]) { + has_parent_with_unknown_shape = true; + stack.push_back({src_node, false}); + } + } + } + } + if (!has_parent_with_unknown_shape) { + if (n == send_node) { + // The shapes of all the inputs to send_node are statically known. We + // won't have to do any inference at compile time so return now: the + // shapes were stored in static_shape_out above. + graphdef_out->reset(); + return Status::OK(); + } else { + // Any shape that is being processed is either the original send node + // or has at least one output with statically-unknown shape. If the + // latter and it doesn't have any inputs with statically-unknown + // shape, then check that it is of the recv nodes that we can fill in + // the shape of at run-time later. If it isn't one of those, then we + // won't have any additional knowledge at compile time, so we already + // know we won't be able to do shape inference and we can return an + // error now. + if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because shape of node ", n->name(), + " will not be known at compilation time."); + } + } + } + } + } + + graphdef_out->reset(new GraphDef()); + graph_out->ToGraphDef(graphdef_out->get()); + + return Status::OK(); +} + +Status Encapsulator::MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // First copy all ancestor nodes of sink_nodes into a new graph. + pruned_graph->reset(new Graph(library)); + (*pruned_graph)->set_versions(graph.versions()); + ReverseDFSFrom(graph, sink_nodes, + /*enter=*/nullptr, + /*leave=*/[&](Node* n) { + if (!n->IsSource()) { + Node* copied = (*pruned_graph)->CopyNode(n); + node_images->emplace(n, copied); + } + }); + + // Add all the edges between copied nodes. + for (auto entry : *node_images) { + const Node* orig = entry.first; + Node* image = entry.second; + for (const Edge* out_edge : orig->out_edges()) { + auto iter = node_images->find(out_edge->dst()); + if (iter != node_images->end()) { + // The source and destination are both in the copied graph. + (*pruned_graph) + ->AddEdge(image, out_edge->src_output(), iter->second, + out_edge->dst_input()); + } + } + } + + // Find all the function call nodes, and inline them. + std::vector function_nodes; + for (auto node : (*pruned_graph)->nodes()) { + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data)); + if (op_reg_data->is_function_op) { + function_nodes.push_back(node); + } + } + for (auto node : function_nodes) { + VLOG(2) << "Inlining function " << node->name(); + const FunctionDef* fdef = library->Find(node->type_string()); + if (fdef == nullptr) { + return errors::Internal("Failed to find function ", node->type_string(), + " in function library."); + } + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR( + FunctionDefToBodyHelper(*fdef, node->attrs(), library, + [library](const string& op, const OpDef** sig) { + return library->LookUpOpDef(op, sig); + }, + &fbody)); + InlineFunctionBody(*library, pruned_graph->get(), node, fbody); + delete fbody; + } + + return Status::OK(); +} + +Status Encapsulator::MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // Find all the send_from_host nodes in all subgraphs, to use as roots for the + // pruning. + std::vector send_from_host_nodes; + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + if (send_node != nullptr) { + send_from_host_nodes.push_back(send_node); + } + } + } + + // Make a copy of all the graph nodes needed to evaluate the send_from_host + // nodes, inlining any functions as needed. + TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( + graph, send_from_host_nodes, pruned_graph, node_images, library)); + + // Perform shape inference on the pruned graph. + shape_refiner->set_require_shape_inference_fns(false); + FixupSourceAndSinkEdges(pruned_graph->get()); + std::vector post_order; + GetReversePostOrder(*(*pruned_graph), &post_order); + for (auto node : post_order) { + // Ignore the status returned by the shape_refiner. At this point we want + // the best effort shapes, even if no shape function is registered for a + // node. + Status status = shape_refiner->AddNode(node); + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << status; + } + } + + return Status::OK(); +} + +Status Encapsulator::GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library) { + std::unique_ptr pruned_graph; + ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); + std::unordered_map node_images; + TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( + *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + // Find all the recv_at_host nodes in this subgraph. + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + std::unordered_set recv_at_host_names; + for (const auto& name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(name); + if (recv_node != nullptr) { + recv_at_host_names.insert(recv_node->name()); + } + } + // For each send_from_host node, do as much shape inference as possible + // without knowing the shape of the recv_at_host nodes, and store the + // result, along with enough information to complete the job at compile time + // once the recv_at_host shapes are known. + for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + std::vector static_shape; + std::unique_ptr graphdef; + if (send_node != nullptr) { + TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( + *pruned_graph, shape_refiner, recv_at_host_names, + node_images[send_node], library, &static_shape, &graphdef)); + if (graphdef == nullptr) { + VLOG(2) << "Send node " << send_node->name() << " shapes"; + for (int i = 0; i < static_shape.size(); ++i) { + VLOG(2) << static_shape[i].DebugString(); + } + } else { + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef->DebugString(); + } + } + TF_RETURN_IF_ERROR( + subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + } + if (!outside_compilation_names.empty()) { + TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); + } + } + + return Status::OK(); +} + +Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. std::unordered_map node_images; @@ -1522,6 +1968,9 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, TF_RETURN_IF_ERROR( AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + GetShapeInfoForOutsideCompilationSends(graph_out, library)); + return Status::OK(); } @@ -1545,7 +1994,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get())); + encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); *graph_out = std::move(out); return Status::OK(); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index b100861d5e9c04a8f9d32d486e0ee7252b79c62b..aed9cae0f1799c4524da8ee309344849798755d5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -29,17 +29,181 @@ limitations under the License. namespace tensorflow { namespace { +template +bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, + const ::tensorflow::protobuf::Map& b, + const std::function& key_to_string, + const std::function& value_to_string, + const std::function& compare, + const string& map_name, string* diff) { + for (const auto& elt_a : a) { + const auto iter = b.find(elt_a.first); + if (iter == b.end()) { + if (diff) { + *diff = strings::StrCat( + map_name, " expected: contains element with key '", + key_to_string(elt_a.first), "' got: map has no such element"); + } + return false; + } + if (!compare(elt_a.first, elt_a.second, iter->second)) { + if (diff) { + *diff = strings::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), " has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); + } + return false; + } + } + for (const auto& elt_b : b) { + const auto iter = a.find(elt_b.first); + if (iter == a.end()) { + if (diff) { + *diff = strings::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); + } + return false; + } + } + return true; +} + +bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, + const string& diff_preamble, string* diff) { + if (a.op() != b.op()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); + } + return false; + } + if (a.device() != b.device()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); + } + return false; + } + if (a.input_size() != b.input_size()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); + } + return false; + } + for (int i = 0; i < a.input_size(); ++i) { + if (a.input(i) != b.input(i)) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), + " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } + } + return EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + if (key == "shape_inference_graph") { + // Default serialization of GraphDef is unstable because maps don't + // serialize deterministically. Rather than go through the hoops to + // turn on deterministic serialization of this attr just for this + // test, add logic here to compare determinstically. + GraphDef ga; + if (!ga.ParseFromString(av.s())) { + return false; + } + GraphDef gb; + if (!gb.ParseFromString(bv.s())) { + return false; + } + return EqualGraphDef(ga, gb, nullptr); + } else { + return av.DebugString() == bv.DebugString(); + } + }, + strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), + diff); +} + bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { - // TODO(phawkins) use a more sophisticated equality test. - if (a.DebugString() != b.DebugString()) { + if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Definition mismatch for function ", + *diff = strings::StrCat("Signature mismatch for function ", a.signature().name(), ", expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } + if (!EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + return av.DebugString() == bv.DebugString(); + }, + strings::StrCat("attr mismatch for function ", a.signature().name()), + diff)) { + return false; + } + if (!EqualProtoMap( + a.ret(), b.ret(), [](const string& s) { return s; }, + [](const string& s) { return s; }, + [](const string& key, const string& av, const string& bv) { + return av == bv; + }, + strings::StrCat("ret mismatch for function ", a.signature().name()), + diff)) { + return false; + } + for (int i = 0; i < a.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < b.node_def_size(); ++j) { + if (a.node_def(i).name() == b.node_def(j).name()) { + if (!EqualFunctionNodeDef( + a.node_def(i), b.node_def(j), + strings::StrCat("Function ", a.signature().name()), diff)) { + return false; + } + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); + } + return false; + } + } + for (int i = 0; i < b.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < a.node_def_size(); ++j) { + if (b.node_def(i).name() == a.node_def(j).name()) { + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); + } + return false; + } + } return true; } @@ -84,29 +248,64 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, // TODO(misard): remove these fake registrations once there are real Ops to be // compiled. -REGISTER_OP("_XlaSendToHost") - .Input("input: dtypes") - .Attr("dtypes: list(type) >= 0"); - -REGISTER_OP("_XlaRecvFromHost") - .Output("output: dtypes") - .Attr("dtypes: list(type) >= 0"); +REGISTER_OP("_XlaHostCompute") + .Input("inputs: Tinputs") + .Output("outputs: Toutputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaSendFromHost") - .Input("input: dtypes") - .Attr("dtypes: list(type) >= 0"); + .Input("input: Tinputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaRecvAtHost") - .Output("output: dtypes") - .Attr("dtypes: list(type) >= 0"); - -REGISTER_OP("InputTest").Output("o: float"); - -REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); + .Output("output: Toutputs") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); + +REGISTER_OP("InputTest") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }); + +REGISTER_OP("InputTestShaped") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }); + +REGISTER_OP("UnaryTest") + .Input("a: float") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + c->set_output(0, o); + return Status::OK(); + }); REGISTER_OP("BinaryTest") .Input("a: float") .Input("b: float") - .Output("o: float"); + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + c->set_output(0, o); + return Status::OK(); + }); +REGISTER_OP("BinaryTest2") + .Input("a: float") + .Input("b: float") + .Output("o: float") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("AddNLikeTest") .Input("inputs: N * T") @@ -124,22 +323,48 @@ Node* Input(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTest", opts); } -Node* RecvAtHost(const gtl::ArraySlice& dtypes, +Node* InputShaped(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("InputTestShaped", opts); +} + +Node* KnownShape(const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", + opts.op_registry()); + TensorProto value; + value.set_dtype(DT_FLOAT); + for (int dim : shape) { + value.mutable_tensor_shape()->add_dim()->set_size(dim); + } + return opts.WithAttr("value", value) + .WithAttr("dtype", DT_FLOAT) + .FinalizeBuilder(&node_builder); +} + +Node* RecvAtHost(const string& key, const gtl::ArraySlice& dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); - return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); + return opts.WithAttr("Toutputs", dtypes) + .WithAttr("key", key) + .FinalizeBuilder(&node_builder); } -Node* SendFromHost(const std::vector& inputs, - const gtl::ArraySlice& dtypes, +Node* SendFromHost(const string& key, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); - return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); + std::vector dtypes; + for (const auto& node : inputs) { + dtypes.push_back(node.dt); + } + return opts.WithAttr("key", key) + .WithAttr("Tinputs", dtypes) + .FinalizeBuilder(&node_builder); } Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { @@ -151,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b, return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); } +Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts); +} + Node* AddNLike(const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; @@ -576,6 +806,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, @@ -584,19 +829,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "c:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, {"c"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -612,11 +856,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* call = b2.opts().FinalizeBuilder(&node_builder); Node* recv = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost({e}, {DT_FLOAT}, + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -674,37 +918,71 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected_1; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape1.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape1.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape1_graph; + TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); + EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + } + + string shape_string_expected_2; + { + GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + shape2.opts().WithName("E")); + Node* recv2 = + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O2_recv")); + Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); + SendFromHost("host_compute_channel_F1_O2", {h}, + shape2.opts().WithName("outside_compilation_F1_O2_send")); + GraphDef shape2_graph; + TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); + EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + } + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, - {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}}, + {{"I"}, + "UnaryTest", + {"outside_compilation_O2_host_compute:outputs:0"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O2_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O2_host_compute"}, + "_XlaHostCompute", {"D:o:0", "F:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", shape_string_expected_2}, + {"shapes", gtl::ArraySlice({})}}, {"F"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected_1}, + {"shapes", gtl::ArraySlice({})}}, {"D"}}, - {{"outside_compilation_O2_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O2_send"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"i_0_retval", "I:o:0"}}); @@ -720,23 +998,24 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* call = b2.opts().FinalizeBuilder(&node_builder); Node* recv1 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost({e}, {DT_FLOAT}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); Node* recv2 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_recv")); Node* g = Binary(e, ops::NodeOut(recv2, 1), b2.opts().WithName("G").WithControlInputs({recv2, e})); Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); - Node* send2 = SendFromHost( - {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send")); + Node* send2 = + SendFromHost("host_compute_channel_F1_O2", {h}, + b2.opts().WithName("outside_compilation_F1_O2_send")); Node* s = NoOp(b2.opts() .WithName("F1_sequencer") @@ -758,8 +1037,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); - Node* b = Input(b1.opts().WithName("B")); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = InputShaped(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); @@ -791,6 +1070,24 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float", "d_0_retval:float"}, {}, @@ -799,19 +1096,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, {"D"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); @@ -822,16 +1118,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", - {"f_0_arg", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"G:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F2_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -839,15 +1135,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); - Node* b = Input(b2.opts().WithName("B")); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = InputShaped(b2.opts().WithName("B")); Node* recv1 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost({e}, {DT_FLOAT}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -857,12 +1153,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* s1 = NoOp( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); - Node* recv2 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* recv2 = + RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F2_O1_recv")); Node* h = Binary(ops::NodeOut(call1, 1), recv2, b2.opts().WithName("H").WithControlInput(s1)); - Node* send2 = SendFromHost( - {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send")); + Node* send2 = + SendFromHost("host_compute_channel_F2_O1", {h}, + b2.opts().WithName("outside_compilation_F2_O1_send")); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(e).Input(call1); @@ -888,7 +1186,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); + Node* a = InputShaped(b1.opts().WithName("A")); Node* b = Input(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = @@ -908,6 +1206,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, { @@ -915,11 +1216,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"D:o:0", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -927,12 +1233,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); + Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); Node* e = Unary(a, b2.opts().WithName("E")); - Node* send1 = SendFromHost( - {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -954,7 +1261,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); + Node* a = InputShaped(b1.opts().WithName("A")); Node* b = Input(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = @@ -975,6 +1282,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, { @@ -982,17 +1292,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"D:o:0", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {}, - {{"dtypes", gtl::ArraySlice({})}}, + {{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}, {"D"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1000,14 +1310,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); + Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); Node* recv1 = - RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + RecvAtHost("host_compute_channel_F1_O1", {}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); - Node* send1 = SendFromHost( - {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -1055,10 +1367,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1069,8 +1385,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); @@ -1118,16 +1435,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, - {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"F"}, + "UnaryTest", {"D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", {}, - {{"dtypes", gtl::ArraySlice({})}}, - {"outside_compilation_O1_send"}}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1138,10 +1458,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); - Node* send1 = SendFromHost({}, {}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -1215,5 +1536,110 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } +// Test for shape inference of outside compilation. +TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. + Node* c = Unary(a, b1.opts().WithName("C")); + Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Node* e = BinaryUnknownShape(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0")); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}}, + {{"F"}, + "BinaryTest", + {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"c:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, + {"c"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + Node* c = Unary(a, b2.opts().WithName("C")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(b).Input(c); + Node* call = + b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); + + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape( + c, ops::NodeOut(recv, 0), + b2.opts().WithName("E").WithControlInputs({recv, b})); + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* s = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 17ae2bb25cac94cee0a1f3df66edf1b3a404e3ec..6353149e4afdf739fe44dd5c76502ef5d98b8477 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -376,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); - TensorShape write_shape; - OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape)); gpu::DeviceMemoryBase buffer = output->buffer({output_num}); @@ -399,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Looks up the owning Tensor by buffer address. OP_REQUIRES_OK( - ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape, + ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape, variable->tensor())); ++output_num; } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 21d3a54f1b8ea59c3da09d8c5d626a9f5bcebbc9..6d854a920eb0b4c01b09024ceaef5035e847d392 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.kind = XlaCompiler::Argument::kConstant; arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); arg.constant_value = input; ++input_num; } @@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args, arg.constant_value = input; } arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); ++input_num; } @@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args, if (variable_args[variable_id].present) { const Tensor& value = variable_args[variable_id].value; arg.type = value.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape)); + arg.shape = value.shape(); arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since @@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args, // uninitialized variables. arg.initialized = false; arg.type = DT_INVALID; - arg.shape = xla::Shape(); + arg.shape = TensorShape(); } ++input_num; } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 314f5506b16e2c28736d9d39aa6c856d50885108..25e329b6aadbab7219d7120ce5f51b3a6f5884e9 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -144,6 +144,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_triangular_solve_op_test", + size = "small", + srcs = ["matrix_triangular_solve_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "clustering_test", size = "small", @@ -240,6 +255,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "extract_image_patches_op_test", + size = "small", + srcs = ["extract_image_patches_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", @@ -326,6 +353,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_band_part_test", + size = "medium", + srcs = ["matrix_band_part_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "momentum_test", size = "small", @@ -437,6 +477,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reverse_sequence_op_test", + size = "small", + srcs = ["reverse_sequence_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "rmsprop_test", size = "small", @@ -613,6 +665,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "gather_nd_op_test", + size = "medium", + srcs = ["gather_nd_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "xla_device_test", size = "small", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index c95fb1c515242ca38369b11aa5e616b13624edf9..30a6d3a74d64f90ad33062df6d1e16e3a575bd63 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) + def testMatrixSetDiag(self): + for dtype in self.numeric_types: + # Square + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]], + dtype=dtype), + np.array([1.0, 2.0, 3.0], dtype=dtype), + expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]], + dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]], + [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]], + dtype=dtype), + np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype), + expected=np.array( + [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]], + [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]], + dtype=dtype)) + + # Rectangular + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype), + np.array([3.0, 4.0], dtype=dtype), + expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype), + np.array([3.0, 4.0], dtype=dtype), + expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], + [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype), + np.array([[-1.0, -2.0], [-4.0, -5.0]], + dtype=dtype), + expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], + [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], + dtype=dtype)) if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0361702e7af778176daed941d64e61198090daf2 --- /dev/null +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for ExtractImagePatches op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ExtractImagePatches(XLATestCase): + """Functional tests for ExtractImagePatches op.""" + + def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): + """Tests input-output pairs for the ExtractImagePatches op. + + Args: + image: Input tensor with shape: [batch, in_rows, in_cols, depth]. + ksizes: Patch size specified as: [ksize_rows, ksize_cols]. + strides: Output strides, specified as [stride_rows, stride_cols]. + rates: Atrous rates, specified as [rate_rows, rate_cols]. + padding: Padding type. + patches: Expected output. + """ + ksizes = [1] + ksizes + [1] + strides = [1] + strides + [1] + rates = [1] + rates + [1] + + with self.test_session(): + image_placeholder = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + out_tensor = array_ops.extract_image_patches( + image_placeholder, + ksizes=ksizes, + strides=strides, + rates=rates, + padding=padding, + name="im2col") + feed_dict = {image_placeholder: image} + self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict)) + + def testKsize1x1Stride1x1Rate1x1(self): + """Verifies that for 1x1 kernel the output equals the input.""" + # [2, 3, 4, 5] + image = np.reshape(range(120), [2, 3, 4, 5]) + # [2, 3, 4, 5] + patches = np.reshape(range(120), [2, 3, 4, 5]) + for padding in ["VALID", "SAME"]: + self._VerifyValues( + image, + ksizes=[1, 1], + strides=[1, 1], + rates=[1, 1], + padding=padding, + patches=patches) + + def testKsize1x1Stride2x3Rate1x1(self): + """Test for 1x1 kernel and strides.""" + # [2, 4, 5, 3] + image = np.reshape(range(120), [2, 4, 5, 3]) + # [2, 2, 2, 3] + patches = image[:, ::2, ::3, :] + for padding in ["VALID", "SAME"]: + self._VerifyValues( + image, + ksizes=[1, 1], + strides=[2, 3], + rates=[1, 1], + padding=padding, + patches=patches) + + def testKsize2x2Stride1x1Rate1x1Valid(self): + """Test for 2x2 kernel with VALID padding.""" + # [1, 2, 2, 1] + image = [[[[1], [2]], [[3], [4]]]] + # [1, 1, 1, 4] + patches = [[[[1, 2, 3, 4]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + patches=patches) + + def testKsize2x2Stride1x1Rate1x1Same(self): + """Test for 2x2 kernel with SAME padding.""" + # [1, 2, 2, 1] + image = [[[[1], [2]], [[3], [4]]]] + # [1, 2, 2, 4] + patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + patches=patches) + + def testKsize2x2Stride1x1Rate2x2Valid(self): + """Test for 2x2 kernel with 2x2 dilation.""" + # [1, 2, 2, 1] + image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32) + # [1, 2, 2, 4] + patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]], + [[4, 6, 12, 14], [5, 7, 13, 15]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[2, 2], + padding="VALID", + patches=patches) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9378b1db7245c0da3e8298e7dcd972491616b0cd --- /dev/null +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -0,0 +1,147 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.gather_nd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class GatherNdTest(XLATestCase): + + def _runGather(self, params, indices): + with self.test_session(): + paramsp = array_ops.placeholder(params.dtype) + indicesp = array_ops.placeholder(indices.dtype) + with self.test_scope(): + gather_nd_t = array_ops.gather_nd(paramsp, indicesp) + feed_dict = {paramsp: params, indicesp: indices} + return gather_nd_t.eval(feed_dict=feed_dict) + + def testSimpleDtype(self): + for dtype in self.numeric_types: + self.assertAllEqual( + np.array([7, 7, 8], dtype=dtype), + self._runGather( + np.array([8, 1, 2, 3, 7, 5], dtype=dtype), + np.array([[4], [4], [0]], np.int32))) + + def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): + with self.test_session(): + params = np.ones((3, 3), dtype=np.float32) + + indices_empty = np.empty((0, 2), dtype=np.int32) + gather_nd_ok_val = self._runGather(params, indices_empty) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + + indices_empty = np.empty((0, 1), dtype=np.int32) + gather_nd_ok_val = self._runGather(params, indices_empty) + self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val) + + params_empty = np.empty((0, 3), dtype=np.float32) + indices_empty = np.empty((0, 2), dtype=np.int32) + gather_nd_ok_val = self._runGather(params_empty, indices_empty) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + + params_empty = np.empty((0, 3), dtype=np.float32) + indices_nonempty = np.zeros((1, 2), dtype=np.int32) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, r"Gather dimension 0 is of size zero"): + self._runGather(params_empty, indices_nonempty) + + def testIndexScalar(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array([4, 1], dtype=np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array(7), gather_nd_val) + + def testParamsRankLargerThanIndexIndexScalarSlices(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array( + [ + 4, + ], dtype=np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array([-7, 7]), gather_nd_val) + + def testParamsRankLargerThanIndexSlices(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array([[4], [4], [0]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array([[-7, 7], [-7, 7], [-8, 8]]), gather_nd_val) + + def testHigherRankParamsLargerThanIndexSlices(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[4], [4], [0]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(params[[4, 4, 0]], gather_nd_val) + + def testEmptyIndicesLastRankMeansCopyEntireTensor(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[], []], dtype=np.int32) # Size (2, 0) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual( + np.vstack((params[np.newaxis, :], params[np.newaxis, :])), + gather_nd_val) + + def testHigherRankParamsAndIndicesLargerThanIndexSlices(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[[3], [2], [1]], [[4], [4], [0]]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(params[[3, 2, 1, 4, 4, 0]].reshape(2, 3, 2, 2), + gather_nd_val) + + def testHigherRankParams(self): + shape = (10, 20, 5, 1, 17) + params = np.random.rand(*shape).astype(np.float32) + indices = np.vstack( + [np.random.randint(0, s, size=2000, dtype=np.int32) for s in shape]).T + gather_nd_val = self._runGather(params, indices) + + expected = params[tuple(indices.T)] + self.assertAllEqual(expected, gather_nd_val) + + def testHigherRankParamsAndIndices(self): + shape = (10, 20, 5, 1, 17) + params = np.random.rand(*shape).astype(np.float32) + indices = np.vstack( + [np.random.randint(0, s, size=2000, dtype=np.int32) for s in shape]).T + indices_reshaped = indices.reshape([10, 10, 20, 5]) + gather_nd_val = self._runGather(params, indices_reshaped) + expected = params[tuple(indices.T)] + self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py new file mode 100644 index 0000000000000000000000000000000000000000..29394f9ea5139b30f88f53de0469b27e37d79195 --- /dev/null +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class MatrixBandPartTest(XLATestCase): + + def _testMatrixBandPart(self, dtype, shape): + with self.test_session(): + batch_shape = shape[:-2] + mat = np.ones(shape).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, shape[-2] - 1: + for upper in -1, 0, 1, shape[-1] - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, + constant_op.constant(lower, dtype=dtypes.int32), + constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) + + def testMatrixBandPart(self): + for dtype in self.float_types: + for batch_shape in [[], [2,], [1, 3, 2]]: + for rows in 1, 2, 7: + for cols in 1, 2, 7: + self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cccb7f5789dce39ef8c3d4b3a7573aaa983b3fbd --- /dev/null +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -0,0 +1,130 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.MatrixTriangularSolve.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def MakePlaceholder(x): + return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) + + +class MatrixTriangularSolveOpTest(XLATestCase): + + def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, + placeholder_b, a, clean_a, b, verification, + atol): + feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b} + verification_np = sess.run(verification, feed_dict) + self.assertAllClose(b, verification_np, atol=atol) + + def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): + clean_a = np.tril(a) if lower else np.triu(a) + with self.test_session() as sess: + placeholder_a = MakePlaceholder(a) + placeholder_ca = MakePlaceholder(clean_a) + placeholder_b = MakePlaceholder(b) + with self.test_scope(): + x = linalg_ops.matrix_triangular_solve( + placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) + verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) + self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, + placeholder_b, a, clean_a, b, + verification, atol) + + def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4): + transp = lambda x: np.swapaxes(x, -1, -2) + for lower, adjoint in itertools.product([True, False], repeat=2): + self._VerifyTriangularSolve( + a if lower else transp(a), b, lower, adjoint, atol) + + def testBasic(self): + rng = np.random.RandomState(0) + a = np.tril(rng.randn(5, 5)) + b = rng.randn(5, 7) + for dtype in self.float_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBasicNotActuallyTriangular(self): + rng = np.random.RandomState(0) + a = rng.randn(5, 5) # the `a` matrix is not lower-triangular + b = rng.randn(5, 7) + for dtype in self.float_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBasicComplexDtypes(self): + rng = np.random.RandomState(0) + a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j) + b = rng.randn(5, 7) + rng.randn(5, 7) * 1j + for dtype in self.complex_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBatch(self): + rng = np.random.RandomState(0) + shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)), + ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))] + tuples = itertools.product(self.float_types, shapes) + for dtype, (a_shape, b_shape) in tuples: + n = a_shape[-1] + a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n) + b = rng.randn(*b_shape) + self._VerifyTriangularSolveCombo( + a.astype(dtype), b.astype(dtype), atol=1e-3) + + def testLarge(self): + n = 1024 + rng = np.random.RandomState(0) + a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n) + b = rng.randn(n, n) + self._VerifyTriangularSolve( + a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) + + def testNonSquareCoefficientMatrix(self): + rng = np.random.RandomState(0) + for dtype in self.float_types: + a = rng.randn(3, 4).astype(dtype) + b = rng.randn(4, 4).astype(dtype) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(a, b) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(a, b) + + def testWrongDimensions(self): + randn = np.random.RandomState(0).randn + for dtype in self.float_types: + lhs = constant_op.constant(randn(3, 3), dtype=dtype) + rhs = constant_op.constant(randn(4, 3), dtype=dtype) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(lhs, rhs) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(lhs, rhs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5d05094e53cfecd9476d7d87f023e8a02d7458 --- /dev/null +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -0,0 +1,93 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.reverse_sequence_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ReverseSequenceTest(XLATestCase): + + def _testReverseSequence(self, + x, + batch_axis, + seq_axis, + seq_lengths, + truth, + expected_err_re=None): + with self.test_session(): + p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) + lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) + with self.test_scope(): + ans = array_ops.reverse_sequence( + p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths) + if expected_err_re is None: + tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths}) + self.assertAllClose(tf_ans, truth, atol=1e-10) + else: + with self.assertRaisesOpError(expected_err_re): + ans.eval(feed_dict={p: x, lengths: seq_lengths}) + + def testSimple(self): + x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32) + expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32) + self._testReverseSequence( + x, + batch_axis=0, + seq_axis=1, + seq_lengths=np.array([1, 3, 2], np.int32), + truth=expected) + + def _testBasic(self, dtype, len_dtype): + x = np.asarray( + [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]], + [[17, 18, 19, 20], [21, 22, 23, 24]]], + dtype=dtype) + x = x.reshape(3, 2, 4, 1, 1) + x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 + + # reverse dim 2 up to (0:3, none, 0:4) along dim=0 + seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype) + + truth_orig = np.asarray( + [ + [[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3 + [[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none + [[20, 19, 18, 17], [24, 23, 22, 21]] + ], # reverse 0:4 (all) + dtype=dtype) + truth_orig = truth_orig.reshape(3, 2, 4, 1, 1) + truth = truth_orig.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 + + seq_axis = 0 # permute seq_axis and batch_axis (originally 2 and 0, resp.) + batch_axis = 2 + self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth) + + def testSeqLength(self): + for dtype in self.all_types: + for seq_dtype in self.int_types: + self._testBasic(dtype, seq_dtype) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 8e4b8a38336c5e8b2e10edc4c81502eeebb628d2..3d3e112f4821ea8e57ea9589a5b4433647ad294b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -154,6 +154,21 @@ class UnaryOpsTest(XLATestCase): def testFloatOps(self): for dtype in self.float_types: + x = np.arange(-0.90, 0.90, 0.25) + self._assertOpOutputMatchesExpected( + math_ops.acos, + x.astype(dtype), + expected=np.arccos(x).astype(dtype)) + self._assertOpOutputMatchesExpected( + math_ops.asin, + x.astype(dtype), + expected=np.arcsin(x).astype(dtype)) + x = np.arange(-3, 3).reshape(1, 3, 2) + self._assertOpOutputMatchesExpected( + math_ops.atan, + x.astype(dtype), + expected=np.arctan(x).astype(dtype)) + self._assertOpOutputMatchesExpected( math_ops.acosh, np.array([1, 2, 3, 4], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1d9e0fb33ee4a4229c78d116831e95391a5ac3f8..bf304102ede610e952a424f0b24505a14692f8ed 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -427,16 +427,36 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, // identity nodes are values used by the loop body or condition. // The Identity node may have the wrong device so copy the device from // one of its outputs instead. + std::deque possible_exit; for (const Edge* edge : arg.switch_node->out_edges()) { - if (edge->src_output() == 0 && IsExit(edge->dst())) { + if (edge->src_output() == 0) { + possible_exit.push_back(edge); + } + if (IsIdentity(edge->dst())) { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } + } + // TODO(b/67425339): Allow general graph between switch and exit. + while (!possible_exit.empty()) { + const Edge* edge = possible_exit.front(); + possible_exit.pop_front(); + if (IsExit(edge->dst())) { if (arg.exit != nullptr) { return errors::InvalidArgument("Duplicate Exit successors to ", arg.switch_node->name()); } arg.exit = edge->dst(); - } else if (StringPiece(edge->dst()->type_string()) == "Identity") { - TF_RETURN_IF_ERROR( - SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } else { + if (!IsIdentity(edge->dst())) { + return errors::Unimplemented("General graph between switch (", + arg.switch_node->name(), + ") and exit node of frame ", + frame->name, " not supported yet."); + } + for (const Edge* out : edge->dst()->out_edges()) { + possible_exit.push_back(out); + } } } } diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md index 82b3b46a2f1e97001d1e0c6b993ec243170bc7d8..91351421bcacd26c41b5c9f98ea833730e4aef30 100644 --- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -6,6 +6,9 @@ Operator | Type Constraint `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`AdjustContrastv2` | +`AdjustHue` | +`AdjustSaturation` | `All` | `Tidx={int32,int64}` `Angle` | `Tout={double,float}`
`T={complex64}` `Any` | `Tidx={int32,int64}` @@ -34,7 +37,7 @@ Operator | Type Constraint `BroadcastGradientArgs` | `T={int32,int64}` `Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` `Ceil` | `T={double,float}` -`Cholesky` | `T={complex64,double,float}` +`Cholesky` | `T={double,float}` `Complex` | `Tout={complex64}`
`T={double,float}` `ComplexAbs` | `Tout={double,float}`
`T={complex64}` `Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` @@ -68,7 +71,11 @@ Operator | Type Constraint `Exp` | `T={complex64,double,float}` `ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Expm1` | `T={complex64,double,float}` -`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}` +`FFT` | +`FFT2D` | +`FFT3D` | +`Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` `FloorMod` | `T={double,float,int32,int64}` @@ -80,6 +87,13 @@ Operator | Type Constraint `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`HSVToRGB` | `T={double,float}` +`IFFT` | +`IFFT2D` | +`IFFT3D` | +`IRFFT` | +`IRFFT2D` | +`IRFFT3D` | `Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Imag` | `Tout={double,float}`
`T={complex64}` @@ -105,11 +119,14 @@ Operator | Type Constraint `MatMul` | `T={complex64,double,float}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` `Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` @@ -131,6 +148,10 @@ Operator | Type Constraint `PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `QuantizeAndDequantizeV2` | `T={double,float}` +`RFFT` | +`RFFT2D` | +`RFFT3D` | +`RGBToHSV` | `T={double,float}` `RandomStandardNormal` | `dtype={float}` `RandomUniform` | `T={int32,int64}`
`dtype={double,float}` `RandomUniformInt` | `T={int32,int64}`
`Tout={int32,int64}` @@ -146,6 +167,8 @@ Operator | Type Constraint `Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` `ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` `Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResizeBilinear` | `T={double,float,int32,int64}` +`ResizeBilinearGrad` | `T={double,float}` `ResourceApplyAdagrad` | `T={double,float}` `ResourceApplyAdam` | `T={double,float}` `ResourceApplyFtrl` | `T={double,float}` @@ -156,6 +179,7 @@ Operator | Type Constraint `ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` `ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseSequence` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` `RightShift` | `T={int32,int64,uint32,uint64}` `Rint` | `T={double,float}` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md index d4b7621ad2858fe17e93d292dd807e4f7c1c336b..b9bdb829d773825005a8921f48d28b6892d8f0cd 100644 --- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -6,6 +6,9 @@ Operator | Type Constraint `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`AdjustContrastv2` | +`AdjustHue` | +`AdjustSaturation` | `All` | `Tidx={int32,int64}` `Angle` | `Tout={double,float}`
`T={complex64}` `Any` | `Tidx={int32,int64}` @@ -34,7 +37,7 @@ Operator | Type Constraint `BroadcastGradientArgs` | `T={int32,int64}` `Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` `Ceil` | `T={double,float}` -`Cholesky` | `T={complex64,double,float}` +`Cholesky` | `T={double,float}` `Complex` | `Tout={complex64}`
`T={double,float}` `ComplexAbs` | `Tout={double,float}`
`T={complex64}` `Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` @@ -68,7 +71,11 @@ Operator | Type Constraint `Exp` | `T={complex64,double,float}` `ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Expm1` | `T={complex64,double,float}` -`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}` +`FFT` | +`FFT2D` | +`FFT3D` | +`Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` `FloorMod` | `T={double,float,int32,int64}` @@ -80,6 +87,13 @@ Operator | Type Constraint `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`HSVToRGB` | `T={double,float}` +`IFFT` | +`IFFT2D` | +`IFFT3D` | +`IRFFT` | +`IRFFT2D` | +`IRFFT3D` | `Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Imag` | `Tout={double,float}`
`T={complex64}` @@ -105,11 +119,14 @@ Operator | Type Constraint `MatMul` | `T={complex64,double,float}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` `Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` @@ -131,6 +148,10 @@ Operator | Type Constraint `PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `QuantizeAndDequantizeV2` | `T={double,float}` +`RFFT` | +`RFFT2D` | +`RFFT3D` | +`RGBToHSV` | `T={double,float}` `Range` | `Tidx={double,float,int32,int64}` `Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` @@ -143,6 +164,8 @@ Operator | Type Constraint `Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` `ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` `Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResizeBilinear` | `T={double,float,int32,int64}` +`ResizeBilinearGrad` | `T={double,float}` `ResourceApplyAdagrad` | `T={double,float}` `ResourceApplyAdam` | `T={double,float}` `ResourceApplyFtrl` | `T={double,float}` @@ -153,6 +176,7 @@ Operator | Type Constraint `ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` `ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseSequence` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` `RightShift` | `T={int32,int64,uint32,uint64}` `Rint` | `T={double,float}` diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 02215b5112d37f726604da2c2caa4f804388d6e5..1418d95956e1536292d58dfc4c2b53c53421fa94 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, for (int i = 0; i < args->size(); ++i) { XlaCompiler::Argument& arg = (*args)[i]; arg.type = ctx->input_type(i); - - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); if (arg.type == DT_RESOURCE) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 5e1b01878b74f2fbc2e84f8c2db1fa37c2c1eb0e..4c6b29bd015d048f842906cc509a6ed564629b73 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -31,6 +31,7 @@ tf_kernel_library( "diag_op.cc", "dynamic_stitch_op.cc", "elu_op.cc", + "extract_image_patches_op.cc", "fft_ops.cc", "fill_op.cc", "function_ops.cc", @@ -43,6 +44,9 @@ tf_kernel_library( "l2loss_op.cc", "lrn_ops.cc", "matmul_op.cc", + "matrix_band_part_op.cc", + "matrix_set_diag_op.cc", + "matrix_triangular_solve_op.cc", "mirror_pad_op.cc", "no_op.cc", "one_hot_op.cc", @@ -58,6 +62,7 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "reverse_sequence_op.cc", "scan_ops.cc", "segment_reduction_ops.cc", "select_op.cc", @@ -82,7 +87,6 @@ tf_kernel_library( "variable_ops.cc", ], hdrs = [ - "gather_op.h", "index_ops.h", "shape_util.h", ], @@ -92,6 +96,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", "//tensorflow/compiler/xla:array4d", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index a015b8e0e8949f8aaa03a78b0f88b7ea8d6aaa1c..b0ba25b9983c3a9af26728ce4b1c263c844327db 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,8 +28,9 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = - BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_); + auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, + /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); OP_REQUIRES_OK(ctx, result.status()); ctx->SetOutput(0, result.ValueOrDie()); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 87d858f763560be454c162e0cf40307c68217663..fe6651793dc763d13f4a4b0ac294ec3ecf64af8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -33,7 +33,7 @@ class CholeskyOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp); +REGISTER_XLA_OP(Name("Cholesky").TypeConstraint("T", kFloatTypes), CholeskyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2970eae20a3fb71f06619f476a49d41b22bca56 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +class ExtractImagePatchesOp : public XlaOpKernel { + public: + explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorFormat data_format = FORMAT_NHWC; + const int num_dims = ksizes_.size(); + + OP_REQUIRES( + ctx, num_dims >= 3, + errors::InvalidArgument("Kernel size must have at least 3 dimensions")); + const int num_spatial_dims = num_dims - 2; + + OP_REQUIRES(ctx, strides_.size() == num_dims, + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims, " dimensions")); + OP_REQUIRES(ctx, dilations_.size() == num_dims, + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims, " dimensions")); + + int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format); + OP_REQUIRES( + ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "kernel sizes > 1 in the batch and depth " + "dimensions.")); + OP_REQUIRES( + ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + + for (int i = 0; i < num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + OP_REQUIRES( + ctx, ksizes_[input_dim] >= 0, + errors::Unimplemented("Kernel size values must be non-negative; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + OP_REQUIRES(ctx, strides_[input_dim] >= 1, + errors::Unimplemented("Stride values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type)); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES( + ctx, input_shape.dims() == num_dims, + errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString())); + const int64 depth = input_shape.dim_size(feature_dim); + + xla::ComputationBuilder* builder = ctx->builder(); + + // The following code is equivalent to: + // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) + int64 kernel_size = 1; + std::vector lhs_shape(num_dims, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + lhs_shape[i] = ksizes_[input_dim]; + kernel_size *= ksizes_[input_dim]; + } + lhs_shape[num_spatial_dims] = depth; + lhs_shape[num_spatial_dims + 1] = 1; + + // Builds an identity matrix as a broadcast equality of iotas. + // iota = np.arange(np.prod(ksize), depth) + // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) + xla::ComputationDataHandle iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + kernel_size * depth, &iota)); + + auto lhs = builder->Reshape(iota, lhs_shape); + auto filter = builder->ConvertElementType( + builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); + + xla::ConvolutionDimensionNumbers dims; + std::vector window_strides(num_spatial_dims); + std::vector lhs_dilation(num_spatial_dims, 1); + std::vector rhs_dilation(num_spatial_dims); + std::vector> padding(num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(num_spatial_dims); + dims.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + for (int i = 0; i < num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = strides_.at(dim); + rhs_dilation[i] = dilations_.at(dim); + + int64 unused_output_size; + OP_REQUIRES_OK( + ctx, GetWindowedOutputSizeVerboseV2( + input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i], + window_strides[i], padding_, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + xla::ComputationDataHandle conv = + builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, + padding, lhs_dilation, rhs_dilation, dims); + ctx->SetOutput(0, conv); + } + + protected: + std::vector ksizes_; + std::vector dilations_; + std::vector strides_; + Padding padding_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp); +}; + +REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index ffed38249416766850ba10f1069e706570b995fe..e9af1e9c2fcb4922ea3570516419abd04a611a04 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/gather_op.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -26,25 +25,38 @@ limitations under the License. namespace tensorflow { -xla::ComputationDataHandle XlaComputeGatherDynamicSlice( - XlaOpKernelContext* context, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, DataType dtype, - DataType index_type, xla::ComputationBuilder* builder) { +Status XlaGather(const xla::ComputationDataHandle& input, + const TensorShape& input_shape, + const xla::ComputationDataHandle& indices, + TensorShape indices_shape, int64 axis, bool indices_are_nd, + DataType dtype, DataType index_type, + xla::ComputationBuilder* builder, + xla::ComputationDataHandle* gather_output) { + // If the indices are N-dimensional, then the last dimension of indices should + // be of size N and correspond to the N indices. + int64 num_axes = 1; + if (indices_are_nd) { + CHECK_GE(indices_shape.dims(), 1); + num_axes = indices_shape.dim_size(indices_shape.dims() - 1); + indices_shape.RemoveLastDims(1); + } + // Although the indices Tensor is flattened into rank 1 during the lookup, // and each scalar entry is used as an index into the first dimension of the // input, the output is returned with shape: // input.shape[:axis] + indices.shape + input.shape[axis+1:] + const int num_indices = indices_shape.num_elements(); TensorShape input_shape_pre_axis(input_shape); input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); TensorShape input_shape_post_axis(input_shape); - input_shape_post_axis.RemoveDimRange(0, axis + 1); - + input_shape_post_axis.RemoveDimRange(0, axis + num_axes); // Each slice of the input tensor has shape: - // [, 1, ] + // [, 1, ..., 1, ] TensorShape slice_shape(input_shape); - slice_shape.set_dim(axis, 1); + for (int64 i = 0; i < num_axes; ++i) { + slice_shape.set_dim(axis + i, 1); + } TensorShape loop_out_shape; loop_out_shape.AppendShape(input_shape_pre_axis); @@ -62,8 +74,24 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Degenerate case: empty indices. if (num_indices == 0) { - return builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); + *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + out_shape.dim_sizes()); + return Status::OK(); + } + + for (int64 i = 0; i < num_axes; ++i) { + if (input_shape.dim_size(axis + i) == 0) { + return errors::InvalidArgument("Gather dimension ", axis + i, + " is of size zero in tensor with shape ", + input_shape.DebugString()); + } + } + + // Flatten the major dimensions of indices into a single dimension for ease of + // iteration. If there is an axis dimension, we must leave it alone. + std::vector flat_indices_shape = {num_indices}; + if (indices_are_nd) { + flat_indices_shape.push_back(num_axes); } // Specify the shape of the loop-carried Tensor tuple. @@ -76,8 +104,8 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( xla::ShapeUtil::MakeShape(idxtype, {}), // The input array has shape input_shape. Loop invariant. xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), - // The gather indices are reshaped to rank 1. Loop invariant. - xla::ShapeUtil::MakeShape(idxtype, {num_indices}), + // The gather indices are reshaped to flat_indices_shape. Loop invariant. + xla::ShapeUtil::MakeShape(idxtype, flat_indices_shape), // The output array, which is updated on each loop iteration. xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); @@ -86,17 +114,16 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( auto init_i = XlaHelpers::Zero(builder, index_type); auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), loop_out_shape.dim_sizes()); - // Flatten the indices into 1-D for ease of iteration. - auto indices_1d = builder->Reshape(indices, {num_indices}); - auto init = builder->Tuple({init_i, input, indices_1d, init_out}); + auto flat_indices = builder->Reshape(indices, flat_indices_shape); + auto init = builder->Tuple({init_i, input, flat_indices, init_out}); // Construct the while loop condition (i < num_indices) - xla::ComputationBuilder condb(context->builder()->client(), - "GatherWhileCond"); - condb.Lt(condb.GetTupleElement( - condb.Parameter(0, tuple_shape, "GatherWhileTuple"), 0), - XlaHelpers::IntegerLiteral(&condb, index_type, num_indices)); - auto cond_status = condb.Build(); + std::unique_ptr condb = + builder->CreateSubBuilder("GatherWhileCond"); + condb->Lt(condb->GetTupleElement( + condb->Parameter(0, tuple_shape, "GatherWhileTuple"), 0), + XlaHelpers::IntegerLiteral(condb.get(), index_type, num_indices)); + auto cond_status = condb->Build(); auto cond = cond_status.ConsumeValueOrDie(); // Construct the while loop body's function. The implementation of gather is: @@ -104,89 +131,145 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // index = dynamic-slice(indices, i) // xi = dynamic-slice(input, index) // output = dynamic-update-slice(output, xi, i) - xla::ComputationBuilder bodyb(context->builder()->client(), - "GatherWhileBody"); + std::unique_ptr bodyb = + builder->CreateSubBuilder("GatherWhileBody"); { // The four loop carried values. - auto loop_tuple = bodyb.Parameter(0, tuple_shape, "GatherWhileTuple"); - auto i = bodyb.GetTupleElement(loop_tuple, 0); - auto input = bodyb.GetTupleElement(loop_tuple, 1); - auto indices = bodyb.GetTupleElement(loop_tuple, 2); - auto output = bodyb.GetTupleElement(loop_tuple, 3); - - // Slice from the input array. - auto index = bodyb.DynamicSlice(indices, bodyb.Reshape(i, {1}), {1}); - auto start_indices = bodyb.Pad( - bodyb.Reshape(index, {1}), XlaHelpers::Zero(&bodyb, index_type), + auto loop_tuple = bodyb->Parameter(0, tuple_shape, "GatherWhileTuple"); + auto i = bodyb->GetTupleElement(loop_tuple, 0); + auto input = bodyb->GetTupleElement(loop_tuple, 1); + auto indices = bodyb->GetTupleElement(loop_tuple, 2); + auto output = bodyb->GetTupleElement(loop_tuple, 3); + + auto zero_index = XlaHelpers::Zero(bodyb.get(), index_type); + + // Slice the i-th index from the indices array. + xla::ComputationDataHandle index; + auto indices_offset = bodyb->Reshape(i, {1}); + if (indices_are_nd) { + // Slice out the entire nd index, if applicable. + indices_offset = bodyb->Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); + index = bodyb->DynamicSlice(indices, indices_offset, {1, num_axes}); + index = bodyb->Collapse(index, {0, 1}); + } else { + index = bodyb->DynamicSlice(indices, indices_offset, {1}); + } + + // Slice the corresponding data from the input array. + auto start_indices = bodyb->Pad( + index, zero_index, xla::MakeEdgePaddingConfig( {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); - auto slice_i = bodyb.Reshape( - bodyb.DynamicSlice(input, start_indices, slice_shape.dim_sizes()), + auto slice_i = bodyb->Reshape( + bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()), loop_out_slice_shape.dim_sizes()); // Construct the index into the output Tensor 0, ..., , 0, ... std::vector out_index_vals( - loop_out_shape.dims(), - bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1})); - out_index_vals[input_shape_pre_axis.dims()] = bodyb.Reshape(i, {1}); - auto out_index = bodyb.ConcatInDim(out_index_vals, 0); + loop_out_shape.dims(), bodyb->Reshape(zero_index, {1})); + out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1}); + auto out_index = bodyb->ConcatInDim(out_index_vals, 0); // Update the output Tensor - auto updated_output = bodyb.DynamicUpdateSlice(output, slice_i, out_index); + auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index); - bodyb.Tuple({bodyb.Add(i, XlaHelpers::One(&bodyb, index_type)), input, - indices, updated_output}); + bodyb->Tuple({bodyb->Add(i, XlaHelpers::One(bodyb.get(), index_type)), + input, indices, updated_output}); } - auto body_status = bodyb.Build(); + auto body_status = bodyb->Build(); auto body = body_status.ConsumeValueOrDie(); // Construct the While loop, extract and reshape the output. auto gather_while = builder->While(cond, body, init); - auto gather_output = builder->GetTupleElement(gather_while, 3); - return builder->Reshape(gather_output, out_shape.dim_sizes()); + auto result = builder->GetTupleElement(gather_while, 3); + *gather_output = builder->Reshape(result, out_shape.dim_sizes()); + return Status::OK(); } -GatherOpDynamicSlice::GatherOpDynamicSlice(OpKernelConstruction* context) - : XlaOpKernel(context) {} - -void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { - xla::ComputationBuilder* builder = context->builder(); - auto input = context->Input(0); - auto input_shape = context->InputShape(0); - auto indices = context->Input(1); - auto indices_shape = context->InputShape(1); - int64 axis = 0; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), - errors::InvalidArgument("axis must be scalar")); - DataType axis_type = input_type(2); - OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, - errors::InvalidArgument("axis must be int32 or int64")); - - OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); - const auto params_dims = input_shape.dims(); - if (axis < 0) { - axis += params_dims; +class GatherOp : public XlaOpKernel { + public: + explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::ComputationBuilder* builder = context->builder(); + auto input = context->Input(0); + auto input_shape = context->InputShape(0); + auto indices = context->Input(1); + auto indices_shape = context->InputShape(1); + int64 axis = 0; + if (context->num_inputs() == 3) { + const TensorShape axis_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), + errors::InvalidArgument("axis must be scalar")); + DataType axis_type = input_type(2); + OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, + errors::InvalidArgument("axis must be int32 or int64")); + + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); + const auto params_dims = input_shape.dims(); + if (axis < 0) { + axis += params_dims; + } + OP_REQUIRES( + context, 0 <= axis && axis < params_dims, + errors::InvalidArgument("Expected axis in the range [", -params_dims, + ", ", params_dims, "), but got ", axis)); } - OP_REQUIRES( - context, 0 <= axis && axis < params_dims, - errors::InvalidArgument("Expected axis in the range [", -params_dims, - ", ", params_dims, "), but got ", axis)); + + DataType index_type = input_type(1); + OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("indices must be int32 or int64")); + + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + context, XlaGather(input, input_shape, indices, indices_shape, axis, + /*indices_are_nd=*/false, input_type(0), index_type, + builder, &gather)); + context->SetOutput(0, gather); } - DataType index_type = input_type(1); - OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("indices must be int32 or int64")); + private: + TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); +}; - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - context, input, input_shape, indices, indices_shape, axis, input_type(0), - index_type, builder); - context->SetOutput(0, gather); -} +REGISTER_XLA_OP(Name("Gather"), GatherOp); +REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), GatherOp); + +class GatherNdOp : public XlaOpKernel { + public: + explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + DataType params_type = context->input_type(0); + DataType indices_type = context->input_type(1); + + TensorShape params_shape = context->InputShape(0); + TensorShape indices_shape = context->InputShape(1); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape), + errors::InvalidArgument("params must be at least a vector")); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape), + errors::InvalidArgument("indices must be at least a vector")); + const int64 num_axes = indices_shape.dim_size(indices_shape.dims() - 1); + OP_REQUIRES( + context, num_axes <= params_shape.dims(), + errors::InvalidArgument( + "index innermost dimension length must be <= params rank; saw: ", + indices_shape.dim_size(indices_shape.dims() - 1), " vs. ", + params_shape.dims())); + + xla::ComputationBuilder* builder = context->builder(); + auto params = context->Input(0); + auto indices = context->Input(1); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices, + indices_shape, /*axis=*/0, + /*indices_are_nd=*/true, params_type, + indices_type, builder, &gather)); + context->SetOutput(0, gather); + } +}; -REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); -REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), - GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.h b/tensorflow/compiler/tf2xla/kernels/gather_op.h deleted file mode 100644 index df86e1fcdd1a4860ed7ee0c5017d25ccf9d227ea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Declaration of the Gather Op using the XLA dynamic slice implementation. - -#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_H_ - -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/util/bcast.h" - -namespace tensorflow { - -class GatherOpDynamicSlice : public XlaOpKernel { - public: - explicit GatherOpDynamicSlice(OpKernelConstruction* context); - - void Compile(XlaOpKernelContext* context) override; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(GatherOpDynamicSlice); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 2c80395c56d73adad7dc1679ba6423fbe103605a..bd8b92c22d71fe89ab8951ec79f411feef6505e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -30,11 +30,16 @@ namespace tensorflow { // shape input_shape) keyed on indices (of shape indices_shape). // // index_type must be must be DT_INT32 or DT_INT64. -xla::ComputationDataHandle XlaComputeGatherDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, DataType dtype, - DataType index_type, xla::ComputationBuilder* builder); +// If `indices_are_nd` is true, the last dimension of `indices` are treated as +// a multidimensional index values. Otherwise, `indices` is treated as a tensor +// of scalar indices. +Status XlaGather(const xla::ComputationDataHandle& input, + const TensorShape& input_shape, + const xla::ComputationDataHandle& indices, + TensorShape indices_shape, int64 axis, bool indices_are_nd, + DataType dtype, DataType index_type, + xla::ComputationBuilder* builder, + xla::ComputationDataHandle* gather_output); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..faa415a97b053b4b11d015fefcd430210b98118a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class MatrixBandPartOp : public XlaOpKernel { + public: + explicit MatrixBandPartOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + const TensorShape num_lower_in_shape = context->InputShape(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape), + errors::InvalidArgument("num_lower must be scalar, got shape ", + num_lower_in_shape.DebugString())); + + const TensorShape num_upper_in_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape), + errors::InvalidArgument("num_upper must be scalar, got shape ", + num_upper_in_shape.DebugString())); + + xla::ComputationBuilder* builder = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle num_lower = context->Input(1); + xla::ComputationDataHandle num_upper = context->Input(2); + DataType input_type = context->input_type(0); + DataType index_type = context->input_type(1); + + TensorShape batch_shape = input_shape; + batch_shape.RemoveLastDims(2); + const int64 m = input_shape.dim_size(input_shape.dims() - 2); + const int64 n = input_shape.dim_size(input_shape.dims() - 1); + + // Compute 'offset', which is how many diagonals we are above/below the + // diagonal. + xla::ComputationDataHandle iota_m; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); + + xla::ComputationDataHandle iota_n; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); + + auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, + /*broadcast_dimensions=*/{0}); + + // If num_lower or num_upper are negative, include all lower/upper + // diagonals. + auto zero_index = XlaHelpers::Zero(builder, index_type); + num_lower = builder->Select( + builder->Lt(num_lower, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower); + num_upper = builder->Select( + builder->Lt(num_upper, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper); + + auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset), + builder->Le(offset, num_upper)); + indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + + auto zero_input = XlaHelpers::Zero(builder, input_type); + auto output = builder->Select( + indicator, input, + builder->Broadcast(zero_input, input_shape.dim_sizes())); + + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp); +}; +REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2940bdcff75a087c914fdad0cb2426276e41aff --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +class MatrixSetDiagOp : public XlaOpKernel { + public: + explicit MatrixSetDiagOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape diag_shape = context->InputShape(1); + + const int rank = input_shape.dims(); + + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + // Check to make sure the last dimension of diag is equal to the smaller of + // the last two dimensions of input. + const int64 m = input_shape.dim_size(rank - 2); + const int64 n = input_shape.dim_size(rank - 1); + const int64 min_dim = std::min(m, n); + + TensorShape batch_shape = input_shape; + batch_shape.RemoveLastDims(2); + + TensorShape expected_diag_shape = batch_shape; + expected_diag_shape.AddDim(min_dim); + OP_REQUIRES(context, expected_diag_shape == diag_shape, + errors::InvalidArgument( + "must have diagonal.shape == input.shape[:-2] + " + "min(input.shape[-2:]), but received input shape: ", + input_shape.DebugString(), + " and diagonal shape: ", diag_shape.DebugString())); + + xla::ComputationBuilder* builder = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle diag = context->Input(1); + + auto zero = XlaHelpers::Zero(builder, context->input_type(0)); + + // Create an indicator tensor that is true only on the diagonal. + xla::ComputationDataHandle iota_m; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); + xla::ComputationDataHandle iota_n; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); + auto indicator = builder->Eq(iota_m, + builder->Broadcast(iota_n, {m}), + /*broadcast_dimensions=*/{0}); + indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + + // Broadcast diag up to the input shape. Use an implicit broadcast (Add) + // because we need to broadcast on the right. + std::vector diag_broadcast_dims(rank - 1); + std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0); + if (min_dim != m) { + diag_broadcast_dims.back() = rank - 1; + } + diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()), + /*broadcast_dimensions=*/diag_broadcast_dims); + + auto output = builder->Select(indicator, diag, input); + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); +}; + +REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eaed93146460de5a6e8328432302cc75bf36a534 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class MatrixTriangularSolveOp : public XlaOpKernel { + public: + explicit MatrixTriangularSolveOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto result = TriangularSolve( + ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie()); + } + + private: + bool lower_; + bool adjoint_; +}; + +REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d092e2e8d6de76f321d359acfc170092fdbb49c6..d4fb5dd4e06c7c70591262c0d63a91c383a2a6e0 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -276,22 +276,44 @@ class MaxPoolGradOp : public XlaOpKernel { public: MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + if (ctx->num_inputs() == 3) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + int num_dims() const { return num_spatial_dims_ + 2; } + + void Compile(XlaOpKernelContext* ctx) override { + if (ctx->num_inputs() != 3) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 5, + errors::InvalidArgument("Must supply ksize and stride arguments.")); + const TensorShape ksize_shape = ctx->InputShape(3); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); + + const TensorShape stride_shape = ctx->InputShape(4); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); + } + OP_REQUIRES(ctx, ksize_.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " "specify ", num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " "specify ", num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - } - - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { const TensorShape tensor_in_shape = ctx->InputShape(0); const TensorShape tensor_out_shape = ctx->InputShape(1); const TensorShape out_backprop_shape = ctx->InputShape(2); @@ -348,6 +370,10 @@ class MaxPool2DGradOp : public MaxPoolGradOp { } }; REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp); +REGISTER_XLA_OP(Name("MaxPoolGradV2") + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DGradOp); class MaxPool3DGradOp : public MaxPoolGradOp { public: diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bc5d3adb091cd238974c5b69b7a2f8fe639cc68 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class ReverseSequenceOp : public XlaOpKernel { + public: + explicit ReverseSequenceOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); + OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape seq_lens_shape = context->InputShape(1); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape), + errors::InvalidArgument("seq_lens input must be 1-dim, not ", + seq_lens_shape.dims())); + OP_REQUIRES(context, batch_dim_ != seq_dim_, + errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_)); + OP_REQUIRES( + context, seq_dim_ < input_shape.dims(), + errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + seq_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, batch_dim_ < input_shape.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, + seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_, + "), ", "(", seq_lens_shape.num_elements(), + " vs. ", input_shape.dim_size(batch_dim_))); + + xla::ComputationBuilder* builder = context->builder(); + const auto input = context->Input(0); + const auto seq_lens = context->Input(1); + + const int64 batch_size = input_shape.dim_size(batch_dim_); + + const DataType input_type = context->input_type(0); + const DataType seq_lens_type = context->input_type(1); + const int64 max_seq_len = input_shape.dim_size(seq_dim_); + + xla::Shape input_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape, + &input_xla_shape)); + xla::Shape seq_lens_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape, + &seq_lens_xla_shape)); + + const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({ + xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}), + seq_lens_xla_shape, + input_xla_shape, + }); + + // For each entry in the batch, reverse the sequence. + // TODO(b/65689298): generalize the Map() operator to non-scalar cases and + // use it here, instead of a While loop. + + // Condition: lambda (i, _, _): i < batch_size + auto condition_builder = + builder->CreateSubBuilder("reverse_sequence_condition"); + { + auto param = condition_builder->Parameter(0, tuple_shape, "param"); + auto i = condition_builder->GetTupleElement(param, 0); + condition_builder->Lt( + i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, + batch_size)); + } + auto condition = condition_builder->Build(); + OP_REQUIRES_OK(context, condition.status()); + + auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); + { + auto param = body_builder->Parameter(0, tuple_shape, "param"); + auto i = body_builder->GetTupleElement(param, 0); + auto seq_lens = body_builder->GetTupleElement(param, 1); + auto output = body_builder->GetTupleElement(param, 2); + + // seq_len is the sequence length of the current batch element (rank 1) + auto seq_len = body_builder->DynamicSlice( + seq_lens, body_builder->Reshape(i, {1}), {1}); + + // Indices is the offset of the batch element in the input. + auto indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + indices = body_builder->DynamicUpdateSlice( + indices, body_builder->Reshape(i, {1}), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + batch_dim_), + {1})); + + // slice_indices is the offset of the start of the reversed sequence in + // the input. + auto slice_indices = body_builder->DynamicUpdateSlice( + indices, + body_builder->Sub(XlaHelpers::IntegerLiteral( + body_builder.get(), seq_lens_type, max_seq_len), + seq_len), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + seq_dim_), + {1})); + + // Slice out the reversed sequence. The slice will overflow the end of the + // sequence, and the contents of the overflow are implementation-defined. + // However, we will mask off these elements and replace them with elements + // from the original input so their values do not matter. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + auto slice = body_builder->DynamicSlice(output, slice_indices, + slice_shape.dim_sizes()); + + // Shift the reversed sequence to the left. + output = body_builder->DynamicUpdateSlice(output, slice, indices); + + body_builder->Tuple( + {body_builder->Add( + i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + seq_lens, output}); + } + auto body = body_builder->Build(); + OP_REQUIRES_OK(context, body.status()); + + auto loop_output = builder->While( + condition.ValueOrDie(), body.ValueOrDie(), + builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + builder->Rev(input, {seq_dim_})})); + auto output = builder->GetTupleElement(loop_output, 2); + + // Mask out elements after the sequence length. + xla::ComputationDataHandle iota; + OP_REQUIRES_OK( + context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + std::vector dims(input_shape.dims(), 1); + dims[batch_dim_] = batch_size; + auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + + // Broadcast the mask up to the input shape. + mask = + builder->Or(mask, builder->Broadcast(builder->ConstantR0(false), + input_shape.dim_sizes())); + + output = builder->Select(mask, output, input); + context->SetOutput(0, output); + } + + private: + int32 batch_dim_; + int32 seq_dim_; +}; + +REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d77fb768ef4d124c403a1dc9b321c4f29571d806..1a78c7ab9be701d3d02285ed21604f0f856b3f1f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder, // Stack has not been initialized. xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetValue( - dtype, - builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()), - builder->ConstantR0(0)}))); + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the expected shape matches the actual shape. TensorShape actual_shape; @@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel { string name = strings::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, - value, &resource)); - resource->set_tensor_array_size(size); + TensorShape(), value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &resource)); ctx->SetResourceOutput(0, resource); } @@ -164,11 +162,9 @@ class StackPushOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - OP_REQUIRES_OK( - ctx, - resource->SetValue( - dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( + {b->DynamicUpdateSlice(ta, update, start_indices), + b->Add(index, b->ConstantR0(1))}))); ctx->SetOutput(0, value); } @@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel { xla::ComputationDataHandle index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index}))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index f0525a5fb86d6d6f0aae954a916186cffc7f3a9f..91c169428c7a88a8d107a97445aeea999946e3e9 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); } void Compile(XlaOpKernelContext* ctx) override { @@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, &strides_tensor)); - DataType lhs_type; TensorShape lhs_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape)); + xla::ComputationDataHandle lhs; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); const TensorShape rhs_shape = ctx->InputShape(4); @@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel { " does not match r-value shape ", rhs_shape.DebugString(), ". Automatic broadcasting not yet implemented.")); - xla::ComputationDataHandle lhs; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs)); - xla::ComputationDataHandle rhs = ctx->Input(4); gtl::InlinedVector dimensions_to_reverse; @@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel { lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); } private: int32 begin_mask_, end_mask_; int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; + DataType dtype_; }; REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 9224072a3cb92b8ff0e99c79e568ca1a76966ed6..000b50af6bd86b7268c016865fb0856c16053ece 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, TF_RET_CHECK(resource->tensor_array_size() >= 0) << resource->name() << " size " << resource->tensor_array_size(); - TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); - ta_shape.AppendShape(elem_shape); if (!resource->initialized()) { xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetValue( - dtype, builder->Broadcast(zero, ta_shape.dim_sizes()))); + + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the elem_shape matches the TensorArray shape. auto shape_or_status = builder->GetShape(resource->value()); @@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, TensorShape shape; TF_RETURN_IF_ERROR( XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TensorShape ta_shape; + ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", @@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::ComputationBuilder* builder, TensorShape* shape) { - TF_RETURN_IF_ERROR(resource->GetShape(builder, shape)); - if (shape->dims() < 1) { - return errors::InvalidArgument("TensorArray rank must be >= 1"); - } + *shape = resource->shape(); + shape->InsertDim(0, resource->tensor_array_size()); return Status::OK(); } @@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel { // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. xla::ComputationDataHandle value; + TensorShape shape; if (element_shape_.IsFullyDefined()) { - TensorShape shape; CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); @@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel { string name = strings::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - dtype_, value, &var)); - var->set_tensor_array_size(size); + dtype_, shape, value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &var)); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written)); + OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); } @@ -337,8 +337,11 @@ class TensorArrayGatherOp : public XlaOpKernel { } } - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + ctx, + XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0, + /*indices_are_nd=*/false, dtype_, index_type, b, &gather)); ctx->SetOutput(0, gather); } @@ -421,7 +424,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } } - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta)); + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } @@ -525,9 +528,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK( - ctx, resource->SetValue( - dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( + ta, b->Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 5534d1bfa1338c7fe3647cd6aa281c4907dfdf8c..f750f7003be288461f5f10455e58932d1b4e4524 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; xla::ComputationBuilder* b = ctx->builder(); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + DataType type = ctx->input_type(1); + TensorShape var_shape; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); + + TensorShape alpha_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + + TensorShape delta_shape = ctx->InputShape(2); + OP_REQUIRES( + ctx, var_shape.IsSameSize(delta_shape), + errors::InvalidArgument("var and delta do not have the same shape: ", + var_shape.DebugString(), " vs ", + delta_shape.DebugString())); + handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyMomentum must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel { errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle momentum = ctx->Input(4); @@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyAdagrad must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); @@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType var_type, m_type, v_type; TensorShape var_shape, m_shape, v_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape)); - - OP_REQUIRES( - ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(m_type), " vs. ", DataTypeString(v_type))); + xla::ComputationDataHandle var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); TensorShape beta1_power_shape = ctx->InputShape(3); TensorShape beta2_power_shape = ctx->InputShape(4); @@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, m, v; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v)); xla::ComputationDataHandle beta1_power = ctx->Input(3); xla::ComputationDataHandle beta2_power = ctx->Input(4); xla::ComputationDataHandle lr = ctx->Input(5); @@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { DataType type = ctx->input_type(3); - DataType var_type, ms_type, mom_type; TensorShape var_shape, ms_shape, mom_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == ms_type && type == mom_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(ms_type), " vs. ", DataTypeString(mom_type))); + xla::ComputationDataHandle var, ms, mom; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); TensorShape lr_shape = ctx->InputShape(3); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), @@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, ms, mom; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom)); xla::ComputationDataHandle lr = ctx->Input(3); xla::ComputationDataHandle rho = ctx->Input(4); xla::ComputationDataHandle momentum = ctx->Input(5); @@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { xla::ComputationBuilder* b = ctx->builder(); - DataType var_type, accum_type, linear_type; TensorShape var_shape, accum_shape, linear_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); - - OP_REQUIRES( - ctx, dtype == var_type && dtype == accum_type && dtype == linear_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyFtrlV2 must match: ", - DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + xla::ComputationDataHandle var, accum, linear; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, errors::InvalidArgument("lr_power is not a scalar: ", lr_power_shape.DebugString())); - xla::ComputationDataHandle var, accum, linear; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle lr = ctx->Input(4); xla::ComputationDataHandle l1 = ctx->Input(5); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a266e9013c41b88788dbc99849f01c09f3d61348..0c5ad9e5255ffc3dfcfb83335060ae833937b3ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -50,18 +50,41 @@ XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +XLAJIT_MAKE_UNARY( + Acos, + b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + b->Add(XlaHelpers::One(b, input_type(0)), x)))); + // acosh(x) = log(x + sqrt(x^2 - 1)) XLAJIT_MAKE_UNARY( Acosh, b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), XlaHelpers::One(b, input_type(0))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + +// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) +XLAJIT_MAKE_UNARY( + Asin, + b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), + b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), + 0.5)))))); + // asinh(x) = log(x + sqrt(x^2 + 1)) XLAJIT_MAKE_UNARY( Asinh, b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), XlaHelpers::One(b, input_type(0))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + +XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); + // atanh(x) = 0.5 * log((1 + x) / (1 - x)) XLAJIT_MAKE_UNARY( Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 68847ae7a2cb926edd9d29007e24b0db7fb5a75f..71173f5aead47702f0ed9e95b827a6fefd9b7efd 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel { public: explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; - bool initialized = ctx->ReadVariableInput(0, &handle).ok(); - ctx->SetOutput(0, ctx->builder()->ConstantR0(initialized)); + XlaResource* variable; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); + ctx->SetOutput(0, + ctx->builder()->ConstantR0(variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); class ReadVariableOp : public XlaOpKernel { public: - explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK( + ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); ctx->SetOutput(0, handle); } + + private: + DataType dtype_; }; REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); @@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel { public: explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Add(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel { public: explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Sub(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -95,28 +107,21 @@ class ResourceGatherOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); - // Get the shape of the resource tensor. - TensorShape resource_shape; - DataType resource_dtype; - OP_REQUIRES_OK( - ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape)); - - DataType expected_output_dtype = ctx->expected_output_dtype(0); - OP_REQUIRES(ctx, resource_dtype == expected_output_dtype, - errors::InvalidArgument( - "Variable dtype is ", DataTypeString(resource_dtype), - " but expected output dtype is ", - DataTypeString(expected_output_dtype), ".")); + DataType type = ctx->expected_output_dtype(0); + TensorShape resource_shape; xla::ComputationDataHandle resource_handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, + &resource_handle)); auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, resource_handle, resource_shape, indices, indices_shape, 0, - resource_dtype, index_type, builder); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, index_type, + builder, &gather)); ctx->SetOutput(0, gather); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 4a711e4d9b7aedb166a8a0ec9fe9ec2390f01b17..0ff1b65ae9179d506e453f98097cd88083eb2be7 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs( } arg.type = resource->type(); - if (arg.initialized) { - TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape)); - } else { + arg.shape = resource->shape(); + if (!arg.initialized) { *has_uninitialized_vars = true; } arg.tensor_array_size = resource->tensor_array_size(); @@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << xla::ShapeUtil::HumanString(arg.shape) + << " shape: " << arg.shape.DebugString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = ctx->input_type(i); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); } } return Status::OK(); @@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler::Argument& arg = arguments[update.input_index]; if (!arg.initialized) { VLOG(2) << "Update shape for argument " << update.input_index << " " - << xla::ShapeUtil::HumanString(update.shape); + << update.shape.DebugString(); arg.initialized = true; - xla::Shape shape = update.shape; - if (!update.tensor_array_gradients_accessed.empty()) { - shape = xla::ShapeUtil::GetTupleElementShape(shape, 0); - } - std::unique_ptr zero = - xla::Literal::CreateFromShape(shape); - OP_REQUIRES_OK(ctx, resource->SetValue( - update.type, builder->ConstantLiteral(*zero))); + arg.shape = update.shape; + OP_REQUIRES_OK(ctx, + resource->SetTypeAndShape(update.type, update.shape)); + + OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder)); } // Add any TensorArray gradients touched by the body to the enclosing @@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } - - // Recompute the argument shape. - OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape)); } // Recompile the body with the "correct" resource shapes. VLOG(1) << "Recompiling body with corrected resource shapes"; @@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), - /*reset_initial_values=*/false, builder)); + builder->GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name() << " modified: " << update.modified << " type: " << DataTypeString(update.type) - << " shape: " << xla::ShapeUtil::HumanString(update.shape); + << " shape: " << update.shape.DebugString(); // Copies the identity of the resource variable from input to output // unchanged, even if the variable was not modified. ctx->op_kernel_context()->set_output( diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 21ad21f73737a289390ed1ea767db1078d05b466..d184f59e01788829d0ba97092c14d36e5188e4e8 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -60,6 +60,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 9b0e6174475c22e325c090bec5f1d56822e106bc..798f0fa78055e800038e8bf41b4f410b670be7dd 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -25,11 +25,10 @@ limitations under the License. namespace tensorflow { -// The current implementation simply unrolls the computation along the batch -// dimension. xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, + bool conjugate_x, bool conjugate_y) { TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, builder->GetShape(x)); TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, @@ -89,10 +88,10 @@ xla::StatusOr BatchDot( dimensions); } - if (x_shape->element_type() == xla::C64 && transpose_x) { + if (x_shape->element_type() == xla::C64 && conjugate_x) { x = builder->Conj(x); } - if (y_shape->element_type() == xla::C64 && transpose_y) { + if (y_shape->element_type() == xla::C64 && conjugate_y) { y = builder->Conj(y); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index b46bc7417d29dc5b7e9649ac28cc78b57d4b619c..b230e885f10f45a78cdd6e455da3ba55ce589b96 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -27,7 +27,10 @@ namespace tensorflow { // viewed as an element of a batch), and arranges the individual results // in a single output tensor of the same batch size. Each of the // individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. +// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each +// can be elementwise-complex-conjugated by setting the `conjugate_x` or +// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both +// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. // // The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` // and `[..., r_y, c_y]`. @@ -40,11 +43,10 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -// TODO(phawkins): add an option to take the complex conjugate of the LHS or -// RHS. xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y); + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, + bool conjugate_x = false, bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index b3cc489adf6042acb3f56b3a0a6c8fbe43bde629..e795701181dd80a2ff544743d513bffd52fd2399 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -71,11 +71,14 @@ xla::StatusOr CholeskyUnblocked( SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); TF_ASSIGN_OR_RETURN(auto r_squared, BatchDot(builder, r, r, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false)); new_d_squared = builder->Sub(new_d_squared, r_squared); TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, + /*conjugate_x=*/false, + /*conjugate_y=*/false)); } auto new_d_inv = builder->Pow( new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); @@ -134,7 +137,8 @@ xla::StatusOr Cholesky( SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); TF_ASSIGN_OR_RETURN(auto delta, BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false)); TF_ASSIGN_OR_RETURN(auto before, SliceInMinorDims(builder, a, {i, i}, {n, i + k})); TF_ASSIGN_OR_RETURN( @@ -155,6 +159,10 @@ xla::StatusOr Cholesky( SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); TF_ASSIGN_OR_RETURN(auto update, TriangularSolve(builder, factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, /*block_size=*/8)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 2bead7359baaf3582c1230adf0cd4a90046859d2..e083a383be4be0d1b556b63214fe5f70323b4149 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -29,6 +29,7 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. +// TODO(mattjj): handle the complex Hermitian case xla::StatusOr Cholesky( xla::ComputationBuilder* builder, xla::ComputationDataHandle a, int64 block_size = 256); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 579944c3a381e7018b7fee5013d0509158ce21cc..7f72a6073df218b9e2bd4cc0c0b5bb10b5cd4b84 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -24,13 +24,15 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { xla::StatusOr TriangularSolve( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, int64 block_size) { + xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size) { TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, @@ -60,14 +62,15 @@ xla::StatusOr TriangularSolve( batch_dimensions.push_back(a_size); } - const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + if (xla::ShapeUtil::GetDimension(*a_shape, -1) != + xla::ShapeUtil::GetDimension(*a_shape, -2)) { return errors::InvalidArgument( "The 'a' arguments to TriangularSolve must be square matrices: ", xla::ShapeUtil::HumanString(*a_shape)); } - if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) { + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) { return errors::InvalidArgument( "Arguments to TriangularSolve have incompatible matrix shapes: ", xla::ShapeUtil::HumanString(*a_shape), " vs ", @@ -89,6 +92,14 @@ xla::StatusOr TriangularSolve( return output; }; + // Applies a complex conjugation operation if `a` is complex and `conjugate_a` + // is true, otherwise returns its argument. + auto maybe_conj = [&](xla::ComputationBuilder* builder, + xla::ComputationDataHandle x) { + auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; + return perform_conj ? builder->Conj(x) : x; + }; + std::map base_computations; auto get_base_triangular_solve = [&](int k) -> xla::StatusOr { @@ -103,19 +114,35 @@ xla::StatusOr TriangularSolve( prepend_batch_dims({k, k})), "a"); + std::array b_lastd; + if (left_side) { + b_lastd = {k, n}; + } else { + b_lastd = {m, k}; + } auto b_param = sub->Parameter(1, xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims({m, k})), + prepend_batch_dims(b_lastd)), "b"); - // TODO(phawkins): it might make sense to use a while loop here, rather - // than unrolling. - // TODO(phawkins): the left-looking variant of the algorithm might be more - // efficient at block size 1. - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - /*block_size=*/1) - .status()); + // We use a left-looking subroutine on the block diagonal in some common + // cases, while falling back to a recursive call in unsupported cases. The + // left-looking subroutine is written with a While loop and so yields much + // faster compile times. Moreover, the left-looking variant can give + // higher performance on smaller (sub)problems. + if (left_side && lower) { + TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, + b_param, transpose_a, + conjugate_a) + .status()); + } else { + TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, + left_side, lower, transpose_a, + conjugate_a, + /*block_size=*/1) + .status()); + } TF_ASSIGN_OR_RETURN(computation, sub->Build()); } @@ -129,47 +156,396 @@ xla::StatusOr TriangularSolve( // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 // (2008): 4. - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - // if k > 1: - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right', - // kind='Lower', transpose=True, block_size=1) - // else: - // output[..., :, i] = b[..., :, i] / a[..., i, i] + // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if + // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if + // conjugate_a is True. + + if (!left_side && lower == transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) + if (i + k < n) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + } else { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, + BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, i + k}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + } + } + + } else if (left_side && lower != transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < m; i += block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) + if (i + k < m) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); + } else { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); + } + } + } else if (!left_side && lower != transpose_a) { + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) + if (i - k >= 0) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, + BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, 0}, {m, i})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } + } + } else { // left_side && lower == transpose_a + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) + if (i - k >= 0) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, 0}, {i, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } + } + } + + return output; +} + +xla::StatusOr TriangularSolveLeftLooking( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, + builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(*a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape->dimensions(i); + batch_dimensions.push_back(a_size); + } + + auto prepend_batch_dims = [&](std::array indices) { + std::vector output(ndims); + std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); + std::copy(indices.begin(), indices.end(), + output.begin() + batch_dimensions.size()); + return output; + }; + + auto maybe_conj = [&](xla::ComputationBuilder* builder, + xla::ComputationDataHandle x) { + auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; + return perform_conj ? builder->Conj(x) : x; + }; + + // The main computation is performed in a While loop. + + // Allocate the output and set its first or last row, + // output = np.zeros_like(b) + // if transpose_a: + // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] + // else: + // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] + xla::ComputationDataHandle output = Zeros(builder, *b_shape); + { + auto i = transpose_a ? m - 1 : 0; TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); + auto update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + } + + // Construct the initial loop carry tuple, + // if transpose_a: + // init = (m-2, output, a, b) + // else: + // init = (1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + *b_shape, + // The coefficient matrix a is a loop invariant. + *a_shape, + // The right-hand-side matrix b is a loop invariant. + *b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); + auto init = builder->Tuple({init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i >= 0 if transpose_a else i < m + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); + { + auto i = condb->GetTupleElement( + condb->Parameter(0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"), + 0); + if (transpose_a) { + condb->Ge(i, condb->ConstantR0(0)); } else { - update = builder->Div(b_slice, a_slice); + condb->Lt(i, condb->ConstantR0(m)); } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k], - // np.transpose(..., a[i+k:, i:i+k])) - if (i + k < n) { - TF_ASSIGN_OR_RETURN(auto a_slice_2, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/true)); - - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) + // else: + // a_row = a[..., i:i+1, :i] + // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) + // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); + { + auto input_tuple = bodyb->Parameter(0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = bodyb->GetTupleElement(input_tuple, 0); + auto body_out = bodyb->GetTupleElement(input_tuple, 1); + auto body_a = bodyb->GetTupleElement(input_tuple, 2); + auto body_b = bodyb->GetTupleElement(input_tuple, 3); + auto zero = bodyb->ConstantR0(0); + + // Set up some helper functions. + auto prepend_zeros = [&](std::array starts) { + auto zero = bodyb->Reshape(bodyb->ConstantR0(0), {1}); + std::vector padded_starts(ndims, zero); + padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1}); + padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1}); + return bodyb->ConcatInDim(padded_starts, 0); + }; + + auto dynamic_slice = [&](xla::ComputationDataHandle x, + std::array starts, + std::array sizes) { + auto padded_starts = prepend_zeros(starts); + auto padded_sizes = prepend_batch_dims(sizes); + return bodyb->DynamicSlice(x, padded_starts, padded_sizes); + }; + + auto update = [&](xla::ComputationDataHandle x, + xla::ComputationDataHandle update, + std::array starts) { + auto padded_starts = prepend_zeros(starts); + return bodyb->DynamicUpdateSlice(x, update, padded_starts); + }; + + // We'd like to implement this: + // if transpose_a: + // a_row = T(a[..., i+1:, i:i+1]) + // result_row = (b[..., i:i+1, :] + // - np.matmul(a_row, body_out[..., i+1:, :])) + // else: + // result_row = (b[..., i:i+1, :] + // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) + // But since we can't have intermediate array sizes depend on the loop + // counter, we instead exploit the fact that we initialized the output to + // all zeros and use that as zero-padding (doing unnecessary FLOPs). + xla::ComputationDataHandle a_row; + if (transpose_a) { + a_row = dynamic_slice(body_a, {zero, i}, {m, 1}); + } else { + a_row = dynamic_slice(body_a, {i, zero}, {1, m}); } + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + auto result_row = + bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update); + + // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1}); + auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); + body_out = update(body_out, div_result, {i, zero}); + + // if transpose_a: + // return (i - 1, body_out, a, b) + // else: + // return (i + 1, body_out, a, b) + auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? -1 : 1)); + bodyb->Tuple({next_i, body_out, body_a, body_b}); } - return output; + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = builder->While(cond, body, init); + return builder->GetTupleElement(triangular_solve_left_looking_while, 1); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 501d026411c80359c7efa406ece5929a2e46ac1f..e32223bfdddda800b1fd4de3e4f0c8061e0f81d8 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -21,25 +21,50 @@ limitations under the License. namespace tensorflow { -// Solves systems of linear equations with upper or lower triangular matrices by -// backsubstitution. +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. Broadcasting along leading +// dimensions, this routine solves one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// That is, the innermost matrices in the output satisfy a scalar system +// depending on the value of the value of (left_side, transpose_a, conjugate_a) +// according to: +// (F, F, F) => `output[..., i, k] a[..., k, j] = b[..., i, j]`, +// (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`, +// (F, T, F) => `output[..., i, k] a[..., j, k] = b[..., i, j]`, +// (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`, +// (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`, +// (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`, +// (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`, +// (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`, +// where * denotes complex conjugation and where the index `k` is summed over. // -// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form -// square matrices. The strictly upper triangular part of each inner-most matrix -// is assumed to be zero and not accessed. -// `b` is a tensor of shape `[..., M, K]`. -// -// The innermost matrices in the output satisfy matrix equations -// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`. +// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If lower is true (false), then the strictly upper (lower) +// triangular part of each innermost matrix in `a` is assumed to be zero and is +// not accessed. +// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a +// tensor of shape `[..., K, M]`. +// `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). +// `lower` is a boolean, indicating whether the argument `a` is lower-triangular +// (true) or upper-triangular (false). +// `transpose_a` is a boolean indicating whether the matrix `a` is transposed. +// `conjugate_a` is a boolean indicating whether the entries of `a` are complex +// conjugated (independently of whether they are transposed), so that when both +// transpose_a and conjugate_a are true the effect is a Hermitian adjoint. // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right, -// kind=lower, and transposed_a=true. Implement the other possible combinations -// of side, kind and transposed_a. xla::StatusOr TriangularSolve( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, int64 block_size = 256); + xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 256); + +xla::StatusOr TriangularSolveLeftLooking( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 671d9aa4fe0c042a3cc44468074653d51c2be75d..661707062916263fd0d5d935ce41698a7655df02 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -27,32 +27,134 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { namespace { using TriangularSolveTest = xla::ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; +using complex64 = xla::complex64; -XLA_TEST_F(TriangularSolveTest, Simple) { +xla::Array2D AValsLower() { + return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array2D AValsUpper() { + return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; +} + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsLowerComplex() { + return {{2, 0, 0, 0}, + {complex64(3, 1), 6, 0, 0}, + {4, complex64(7, 2), 9, 0}, + {5, 8, complex64(10, 3), 11}}; +} + +xla::Array2D AValsUpperComplex() { + return {{2, 3, complex64(4, 3), 5}, + {0, 6, complex64(7, 2), 8}, + {0, 0, complex64(9, 1), 10}, + {0, 0, 0, 11}}; +} + +xla::Array2D BValsRightComplex() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeftComplex() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::ComputationBuilder builder(client_, TestName()); - xla::Array2D a_vals({ - {2, 0, 0, 0}, - {3, 6, 0, 0}, - {4, 7, 9, 0}, - {5, 8, 10, 11}, + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); - xla::Array2D b_vals({ - {1, 2, 3, 4}, - {5, 6, 7, 8}, - {9, 10, 11, 12}, + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + xla::ComputationDataHandle a, b; - auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(b_vals, 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, /*block_size=*/2); + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); TF_ASSERT_OK(result.status()); xla::Array2D expected({ @@ -62,7 +164,201 @@ XLA_TEST_F(TriangularSolveTest, Simple) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(2e-3, 2e-3)); + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = + CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, complex64(0.08333333, 0.08333333), + complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, + {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), + complex64(0.08670034, -0.02104377)}, + {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296), + complex64(0.11026936, -0.03114478)}, + }); + + ComputeAndCompareR2(&builder, expected, + {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = + CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1., 1.5}, + {0.41666667, 0.33333333, 0.25}, + {complex64(0.20020325, -2.81504065e-01), + complex64(0.13821138, -4.22764228e-01), + complex64(0.07621951, -5.64024390e-01)}, + {complex64(0.19678492, 2.55912786e-01), + complex64(0.17738359, 3.84331116e-01), + complex64(0.15798226, 5.12749446e-01)}, + }); + + ComputeAndCompareR2(&builder, expected, + {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolveLeftLooking(&builder, a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolveLeftLooking(&builder, a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); } } // namespace diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index ce24b61b5dc7176f3caa05e3eb9257399fef7926..9b7492f8cf6e86498d7e2f5d42e42ea978c664d8 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -107,4 +107,15 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } +xla::StatusOr TransposeInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return builder->Transpose(x, permutation); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index fb138b4f736500aac8184770d97fbf930ced69ea..7f93102ee78bec60018814975a0badfeb7874aa6 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -49,6 +49,10 @@ xla::StatusOr UpdateSliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, const xla::ComputationDataHandle& update, gtl::ArraySlice start); +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::StatusOr TransposeInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 906f2290433face4cce3296b2f815d50d8c496ce..6051d7dffd7493d8cffb07c1b5d10500e7e75522 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TensorShape shape; - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 69b265436bb19bbbdd9deb872f4097d4bac7ea52..c5b4ec5b15f90eb43c61cddb7bfd7640fa237a3d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types, bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, tensor_array_size, + if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size, tensor_array_gradients) != std::tie(other.kind, other.resource_kind, other.type, other.name, - other.tensor_array_size, other.tensor_array_gradients)) { + other.initialized, other.tensor_array_size, + other.tensor_array_gradients)) { return false; } - if (!xla::ShapeUtil::Equal(shape, other.shape)) { + if (shape != other.shape) { return false; } if (constant_value.shape() != other.constant_value.shape()) { @@ -230,6 +231,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, return Status::OK(); } +// Computes the XLA shape for argument 'arg'. +/*static*/ Status XlaCompiler::XLAShapeForArgument( + const XlaCompiler::Argument& arg, xla::Shape* xla_shape) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), + xla_shape); + case XlaCompiler::Argument::kParameter: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaCompiler::Argument::kResource: { + TF_RET_CHECK(arg.initialized); + + switch (arg.resource_kind) { + case XlaResource::kVariable: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaResource::kTensorArray: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); + + if (!arg.tensor_array_gradients.empty()) { + std::vector tuple_shape( + arg.tensor_array_gradients.size() + 1, *xla_shape); + *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape); + } + return Status::OK(); + } + case XlaResource::kStack: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + xla::Shape buffer_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); + *xla_shape = xla::ShapeUtil::MakeTupleShape( + {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); + return Status::OK(); + } + + case XlaResource::kInvalid: + return errors::Internal( + "Invalid resource type in XLAShapeForArgument()"); + } + } + case XlaCompiler::Argument::kInvalid: + return errors::Internal("Invalid argument type in XLAShapeForArgument()"); + } +} + namespace { Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, @@ -275,8 +334,9 @@ Status BuildArguments(const Graph& graph, // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. - std::vector parameters, resources; - parameters.reserve(args.size()); + input_mapping->clear(); + input_mapping->reserve(args.size()); + std::vector resources; resources.reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. @@ -290,18 +350,20 @@ Status BuildArguments(const Graph& graph, // TODO(phawkins): this code assumes that resource arguments do not // alias. XlaResource* resource; - TF_RETURN_IF_ERROR( - context->CreateResource(arg.resource_kind, i, arg.name, arg.type, - xla::ComputationDataHandle(), &resource)); - resource->set_tensor_array_size(arg.tensor_array_size); + TF_RETURN_IF_ERROR(context->CreateResource( + arg.resource_kind, i, arg.name, arg.type, arg.shape, + xla::ComputationDataHandle(), + /*tensor_array_size=*/arg.tensor_array_size, + /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); if (arg.initialized) { resources.push_back(i); } break; - case XlaCompiler::Argument::kParameter: - parameters.push_back(i); + case XlaCompiler::Argument::kParameter: { + input_mapping->push_back(i); break; + } case XlaCompiler::Argument::kConstant: arg_expression.set_constant_value(arg.constant_value); break; @@ -312,19 +374,17 @@ Status BuildArguments(const Graph& graph, // Append parameters containing variable values after the other runtime // parameters. - parameters.insert(parameters.end(), resources.begin(), resources.end()); - if (parameters.empty()) { + input_mapping->insert(input_mapping->end(), resources.begin(), + resources.end()); + if (input_mapping->empty()) { return Status::OK(); } - std::vector arg_shapes; - arg_shapes.reserve(parameters.size()); - input_mapping->resize(parameters.size()); - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + std::vector arg_shapes(input_mapping->size()); + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - arg_shapes.push_back(arg.shape); - (*input_mapping)[i] = parameters[i]; + TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument( + args[(*input_mapping)[i]], &arg_shapes[i])); } if (use_tuple_arg) { @@ -354,13 +414,13 @@ Status BuildArguments(const Graph& graph, } // Build parameter handles for non-constant arguments. - std::vector arg_handles(parameters.size()); + std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { xla::ComputationDataHandle tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : parameters) { + for (int64 parameter : *input_mapping) { const int core = (*arg_cores)[parameter]; const int root_device = 0; *tuple_sharding.add_tuple_shardings() = @@ -373,16 +433,16 @@ Status BuildArguments(const Graph& graph, } else { tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -393,19 +453,18 @@ Status BuildArguments(const Graph& graph, // Fill in the handles in non-constant arguments. VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << parameters[i]; - XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; + << " name: " << arg.name << " TF arg " << input_mapping->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); XlaResource* resource = arg_expression.resource(); - TF_RETURN_IF_ERROR( - resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i], - /*reset_initial_values=*/true, builder)); + TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients, + arg_handles[i], builder)); VLOG(2) << " resource: num_gradients: " << arg.tensor_array_gradients.size(); break; @@ -486,6 +545,7 @@ Status BuildComputation( XlaCompiler::ResourceUpdate& update = resource_updates->back(); update.input_index = resource->arg_num(); update.type = resource->type(); + update.shape = resource->shape(); update.modified = modified; for (const auto& grad : resource->tensor_array_gradients()) { update.tensor_array_gradients_accessed.insert(grad.first); @@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, ++computation_output; } } - - for (std::vector::size_type i = 0; - i < result->resource_updates.size(); ++i) { - result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output); - ++computation_output; - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 30d3c05ee9aa33accc0ad122901f70b0b6613104..b86c82c0ab5ce379d35a13043857f459199e2ad2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -104,9 +104,17 @@ class XlaCompiler { // is the type of the variable's value, not DT_RESOURCE. DataType type; - // The shape of the argument. If the argument is a resource, this is the - // shape of the resource's value. - xla::Shape shape; + // The shape of the argument. For: + // * a parameter: the shape of the parameter. + // * a constant: ignored; the shape given by constant_value is used + // instead. + // * an uninitialized resource: ignored. We don't yet know the shape of an + // uninitialized resource (otherwise we would have initialized it!) + // * an initialized variable: the shape of the variable's value. + // * an initialized TensorArray or Stack resource: the shape of an entry in + // the TensorArray/Stack. Note this is the size of a single entry, not the + // XLA data structure that represents the complete stack/array. + TensorShape shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -175,8 +183,9 @@ class XlaCompiler { int input_index; // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. DataType type; - xla::Shape shape; + TensorShape shape; // Was the value of the variable modified by the computation? // (Always true, unless `return_updated_values_for_all_resources` is true.) @@ -266,11 +275,10 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); - Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func, - const std::vector& types, - const std::vector& shapes, - const std::vector& expressions, - std::vector* args); + // Returns the shape of the XLA parameter for an argument 'arg'. + // See the class comment for more details about the argument passing + // convention. + static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 7ebe4b75bc1e33e506624314b11163e36a2477de..65de4dbad75b7fb76a041bc799fc31dc5cb80d74 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].shape = TensorShape({2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -242,10 +242,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].shape = TensorShape({2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); XlaCompiler::Options options = DefaultOptions(); XlaCompiler compiler(options); @@ -373,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); DummyResourceForTest* resource = new DummyResourceForTest(); @@ -420,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); // Compiles the graph. auto options = DefaultOptions(); @@ -472,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad2"}; @@ -540,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; @@ -574,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index e8d17e2e0a1ba01f16d4bbbd2895b112f4dd1989..73878955e3fd54c103c0b07faf7f5ee5bcd84de0 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -103,12 +103,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, xla::ComputationBuilder* XlaContext::builder() { return builder_; } -Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, - string name, DataType type, - const xla::ComputationDataHandle& handle, - XlaResource** resource) { +Status XlaContext::CreateResource( + XlaResource::Kind kind, int arg_num, string name, DataType type, + TensorShape shape, const xla::ComputationDataHandle& handle, + int64 tensor_array_size, const std::set& tensor_array_gradients, + XlaResource** resource) { resources_.emplace_back( - new XlaResource(kind, arg_num, std::move(name), type, handle)); + new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), + handle, tensor_array_size, tensor_array_gradients)); *resource = resources_.back().get(); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 1a7dafe8cdb56cc9b8fcd3ba6e262c21c2a07d90..fac0352ae81e24597e1045981ac47a7cd09481da 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -71,11 +71,15 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::Literal& literal); - // Creates a resource with resource `kind` and initial type `type` and - // value `handle`. `name` is a descriptive name for use in error messages. + // Creates a resource with resource `kind` and initial value `handle`. `name` + // is a descriptive name for use in error messages. See the `XlaResource` + // constructor for a description of the remaining arguments. // Fails if the resource already exists. Status CreateResource(XlaResource::Kind kind, int arg_num, string name, - DataType type, const xla::ComputationDataHandle& handle, + DataType type, TensorShape shape, + const xla::ComputationDataHandle& handle, + int64 tensor_array_size, + const std::set& tensor_array_gradients, XlaResource** resource); const std::vector>& resources() { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index ee0aed672e1b264fee0a7f381c334400c55f3581..ee29158646fa96fe554d089e11d50afb47e3e300 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList( } Status XlaOpKernelContext::ReadVariableInput( - int index, xla::ComputationDataHandle* value) { + int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -296,7 +297,15 @@ Status XlaOpKernelContext::ReadVariableInput( return errors::InvalidArgument("Read of uninitialized variable ", variable->name()); } + if (variable->type() != type) { + return errors::InvalidArgument( + "Type mismatch for read of variable ", variable->name(), ". Expected ", + DataTypeString(type), "; got ", DataTypeString(variable->type())); + } *value = variable->value(); + if (shape) { + *shape = variable->shape(); + } return Status::OK(); } @@ -312,12 +321,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, variable->name()); } *type = variable->type(); - auto shape_or_status = builder()->GetShape(variable->value()); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + *shape = variable->shape(); return Status::OK(); } @@ -405,7 +409,17 @@ Status XlaOpKernelContext::AssignVariable( XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); - return variable->SetValue(type, handle); + + auto shape_or_status = builder()->GetShape(handle); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + TensorShape shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); + return variable->SetValue(handle); } XlaCompiler* XlaOpKernelContext::compiler() const { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 6d3b6db2289d6c0b8f266062f9f3baca1145154a..e1fd0f55c6d2501b4813c90171630a8df567f78a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -164,11 +164,16 @@ class XlaOpKernelContext { TensorShape* shape) const; // Reads the current value of the resouce variable referred to by input - // 'index'. - Status ReadVariableInput(int index, xla::ComputationDataHandle* value); + // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the + // variable. Returns an error if the variable has not been initialized, or if + // its type does not match `type`. + Status ReadVariableInput(int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value); // Assigns the value `handle` to the variable referenced by input - // `input_index`. Marks the operator as having side effects. + // `input_index`. The variable must be of `type`. Returns an error if the + // variable has been initialized with a different type or with a + // different shape. Status AssignVariable(int input_index, DataType type, const xla::ComputationDataHandle& handle); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 9abac8bdaa77c99a57b2f8ac66fe6ed06fbcd102..c2075b44b82ba279d1246ec6bfcf305d12c418a6 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -25,51 +25,99 @@ limitations under the License. namespace tensorflow { -XlaResource::XlaResource(Kind kind, int arg_num, string name, - DataType initial_type, - const xla::ComputationDataHandle& initial_value) +XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients) : kind_(kind), arg_num_(arg_num), name_(std::move(name)), - type_(initial_type), + type_(type), + shape_(std::move(shape)), value_(initial_value), - initial_value_(initial_value) { + initial_value_(initial_value), + tensor_array_size_(tensor_array_size) { CHECK(kind_ != kInvalid); + + for (const string& gradient : tensor_array_gradients) { + tensor_array_gradients_[gradient].reset( + new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + type_, shape_, xla::ComputationDataHandle(), + tensor_array_size_, /*tensor_array_gradients=*/{})); + } } -Status XlaResource::SetValue(DataType type, - const xla::ComputationDataHandle& value) { - if (type_ == DT_INVALID && type == DT_INVALID) { - return errors::InvalidArgument("Attempted to initialized resource ", name_, - " to an invalid type"); +Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { + if (type == DT_INVALID) { + return errors::InvalidArgument("Attempted to set type of resource '", name_, + "'' to an invalid type"); } - if (type_ != DT_INVALID && type_ != type) { + if (initialized() && type_ != type) { return errors::InvalidArgument("Type of resource ", name_, " cannot be changed after initialization: " "old type was ", DataTypeString(type_), ", new type is ", DataTypeString(type)); } + if (initialized() && shape_ != shape) { + return errors::InvalidArgument("Shape of resource ", name_, + " cannot be changed after initialization: " + "old shape was ", + shape_.DebugString(), ", new shape is ", + shape.DebugString()); + } type_ = type; - value_ = value; + shape_ = shape; return Status::OK(); } -Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder, - xla::Shape* shape) const { - auto shape_or_status = builder->GetShape(value_); - if (!shape_or_status.ok()) { - return shape_or_status.status(); +Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); } - *shape = *shape_or_status.ValueOrDie(); + value_ = value; return Status::OK(); } -Status XlaResource::GetShape(xla::ComputationBuilder* builder, - TensorShape* shape) const { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape)); - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape)); +Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); + } + switch (kind_) { + case kVariable: { + value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), + shape_.dim_sizes()); + break; + } + case kTensorArray: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); + break; + } + case kStack: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = + builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + builder->ConstantR0(0)}); + break; + } + + case kInvalid: + default: + LOG(FATAL) << "Invalid resource type"; + } return Status::OK(); } @@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient( std::unique_ptr& gradient = tensor_array_gradients_[source]; if (!gradient) { TensorShape ta_shape; - TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape)); + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); xla::ComputationDataHandle gradient_value = builder->Broadcast( XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/strings::StrCat("TensorArrayGrad: ", name_), - type_, gradient_value)); - gradient->tensor_array_size_ = tensor_array_size_; + type_, shape_, gradient_value, tensor_array_size_, + /*tensor_array_gradients=*/{})); } *gradient_out = gradient.get(); return Status::OK(); } -Status XlaResource::PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const { - if (tensor_array_gradients_.empty()) { - return GetXlaShape(builder, packed_shape); - } - TF_RET_CHECK(kind_ == kTensorArray); - std::vector elem_shapes(1 + tensor_array_gradients_.size()); - int pos = 0; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++])); - for (const auto& gradient : tensor_array_gradients_) { - TF_RETURN_IF_ERROR( - gradient.second->GetXlaShape(builder, &elem_shapes[pos++])); - } - *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); - return Status::OK(); -} - Status XlaResource::Pack(xla::ComputationDataHandle* pack, xla::ComputationBuilder* builder) const { if (tensor_array_gradients_.empty()) { @@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack, Status XlaResource::SetFromPack(const std::set& gradient_sources, const xla::ComputationDataHandle& pack, - bool reset_initial_values, xla::ComputationBuilder* builder) { if (gradient_sources.empty()) { + if (!initialized()) { + initial_value_ = pack; + } value_ = pack; } else { TF_RET_CHECK(kind_ == kTensorArray); int pos = 0; - value_ = builder->GetTupleElement(pack, pos++); + auto v = builder->GetTupleElement(pack, pos++); + if (!initialized()) { + initial_value_ = v; + } + value_ = v; + for (const auto& source : gradient_sources) { XlaResource* gradient; TF_RETURN_IF_ERROR( GetOrCreateTensorArrayGradient(source, builder, &gradient)); - gradient->value_ = builder->GetTupleElement(pack, pos++); - if (reset_initial_values) { - gradient->initial_value_ = gradient->value_; + auto v = builder->GetTupleElement(pack, pos++); + if (!gradient->initialized()) { + gradient->initial_value_ = v; } + gradient->value_ = v; } } - if (reset_initial_values) { - initial_value_ = value_; - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 6b46089e4f5e10c195bb59f78c33305c2fa3f84d..1bb2c7274ecdf0954768fd96def51194e52deee8 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -36,8 +36,11 @@ class XlaResource { kStack, }; - XlaResource(Kind kind, int arg_num, string name, DataType initial_type, - const xla::ComputationDataHandle& initial_value); + XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients); XlaResource(const XlaResource&) = delete; XlaResource(XlaResource&&) = delete; @@ -60,6 +63,12 @@ class XlaResource { // a resource is first initialized we do not yet know its type, so we keep // track of its type dynamically. DataType type() const { return type_; } + + // Shape of the resource. For an uninitialized resource, this is ignored. + // For a Variable, this is the shape of the value. For a TensorArray or Stack + // this is the shape of each entry in the TensorArray/Stack. + const TensorShape& shape() const { return shape_; } + const xla::ComputationDataHandle& value() const { return value_; } // Value of the resource at computation entry. Used to detect which @@ -68,17 +77,19 @@ class XlaResource { return initial_value_; } + // A variable is initialized if it has a value. bool initialized() const { return value_.handle() > 0; } - // Sets the current type/value of the resource. - Status SetValue(DataType type, const xla::ComputationDataHandle& value); + // Sets the type and shape of the resource. The type and shape of a resource + // must not change once the variable has been initialized. + Status SetTypeAndShape(DataType type, const TensorShape& shape); - // Returns the shape of the resource as an xla::Shape. - Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const; + // Sets the current value of the resource. Returns an error if the type is not + // set to a valid value. + Status SetValue(const xla::ComputationDataHandle& value); - // Returns the shape of the resource as an TensorShape. Fails if the shape is - // not representable as a TensorShape. - Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const; + // Sets the current value of the resource to an all-zero value. + Status SetZeroValue(xla::ComputationBuilder* builder); // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A @@ -96,10 +107,6 @@ class XlaResource { Status Pack(xla::ComputationDataHandle* pack, xla::ComputationBuilder* builder) const; - // Returns the shape of the `pack` value computed by `Pack()`. - Status PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const; - // Updates the resource with values from `pack`. If `gradient_sources` is // non-empty, treats `pack` as a tuple that represents a TensorArray and // its gradients, and unpacks and updates the gradient resources. @@ -108,14 +115,14 @@ class XlaResource { // Opposite of Pack(). Status SetFromPack(const std::set& gradient_sources, const xla::ComputationDataHandle& pack, - bool reset_initial_values, xla::ComputationBuilder* builder); - // TensorArray-specific fields + // TensorArray and Stack specific fields // 'tensor_array_size' stores the expected size of the TensorArray or Stack. // We need to store this since sometimes TensorArrays must be initialized // lazily since we do not know the element shape at construction time. + // Used by both TensorArrays and Stacks. int64 tensor_array_size() const { return tensor_array_size_; } void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } @@ -136,6 +143,7 @@ class XlaResource { const string name_; DataType type_; + TensorShape shape_; xla::ComputationDataHandle value_; xla::ComputationDataHandle initial_value_; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index c22fd37129c5344825631ecf422bbcf3434e4534..34e733bc8d80b364cec1783006eba0a5468b55ea 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -88,7 +88,6 @@ cc_library( visibility = [":friends"], deps = [ "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", "//third_party/eigen3", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 952109dde2d1d14845f2dd2fc34118bbce0c7d91..02356699a25e47be50eb15872df4c9c302fc289b 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -80,6 +80,18 @@ cc_library( ], ) +cc_library( + name = "executable_build_options", + srcs = ["executable_build_options.cc"], + hdrs = ["executable_build_options.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/core:lib", + ], +) + cc_library( name = "local_client", srcs = ["local_client.cc"], @@ -87,6 +99,7 @@ cc_library( deps = [ ":client", ":computation", + ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..804e34f5e75ce2d153ac7627b94a543fda88e810 --- /dev/null +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/executable_build_options.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { + +ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator( + DeviceMemoryAllocator* allocator) { + device_allocator_ = allocator; + return *this; +} + +DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const { + return device_allocator_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( + int device_ordinal) { + CHECK_GE(device_ordinal, 0); + device_ordinal_ = device_ordinal; + return *this; +} + +int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } + +ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( + const Shape& shape_with_layout) { + result_layout_set_ = true; + result_layout_ = shape_with_layout; + return *this; +} + +const Shape* ExecutableBuildOptions::result_layout() const { + return result_layout_set_ ? &result_layout_ : nullptr; +} + +string ExecutableBuildOptions::ToString() const { + string result_layout = "nullopt"; + if (result_layout_set_) { + result_layout = ShapeUtil::HumanStringWithLayout(result_layout_); + } + string generate_hlo_graph = "nullopt"; + if (generate_hlo_graph_.has_value()) { + generate_hlo_graph = generate_hlo_graph_.value(); + } + return tensorflow::strings::Printf( + "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " + "generate_hlo_graph=%s}", + device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str()); +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( + string regex) { + generate_hlo_graph_ = std::move(regex); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::generate_hlo_graph() const { + return generate_hlo_graph_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h new file mode 100644 index 0000000000000000000000000000000000000000..3a52dbac9adb155ad9a7d91a8102707f70fe2fbf --- /dev/null +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Class containing options for building an LocalExecutable with +// LocalClient::Compile. +class ExecutableBuildOptions { + public: + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + ExecutableBuildOptions& set_device_ordinal(int device_ordinal); + int device_ordinal() const; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will chose the layout of the result. A Shape is used to + // store the layout to accommodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); + const Shape* result_layout() const; + + // If set, this specifies an allocator that can be used to allocate temporary + // space on the device during compilation. For example, the compiler might + // want to run various algorithms on the device and pick the fastest one -- it + // might allocate buffers for use by these algorithms using this allocator. + // + // This does not need to be the same as the DeviceMemoryAllocator passed when + // running the executable. + ExecutableBuildOptions& set_device_allocator( + DeviceMemoryAllocator* allocator); + DeviceMemoryAllocator* device_allocator() const; + + // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). + ExecutableBuildOptions& set_generate_hlo_graph(string regex); + const tensorflow::gtl::optional& generate_hlo_graph() const; + + // Returns a string representation of the build options, suitable for + // debugging. + string ToString() const; + + private: + int device_ordinal_ = -1; + Shape result_layout_; + bool result_layout_set_ = false; + tensorflow::gtl::optional generate_hlo_graph_; + DeviceMemoryAllocator* device_allocator_ = nullptr; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index e45787fca61dd5aa6b237e25a5d7ba12aeab5613..ef98dbb6403beedb0c08ab9a0fc9e7d4ee31ab3b 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -30,35 +30,6 @@ using xla::source_map_util::InvalidParameterArgument; namespace xla { -ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( - int device_ordinal) { - device_ordinal_ = device_ordinal; - return *this; -} - -int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } - -ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( - const Shape& shape_with_layout) { - result_layout_set_ = true; - result_layout_ = shape_with_layout; - return *this; -} - -const Shape* ExecutableBuildOptions::result_layout() const { - return result_layout_set_ ? &result_layout_ : nullptr; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator( - DeviceMemoryAllocator* allocator) { - device_allocator_ = allocator; - return *this; -} - -DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const { - return device_allocator_; -} - namespace { StatusOr BorrowStreamForDevice(int device_ordinal, Backend* backend) { @@ -70,16 +41,18 @@ StatusOr BorrowStreamForDevice(int device_ordinal, } // namespace LocalExecutable::LocalExecutable(std::unique_ptr executable, - Backend* backend, int device_ordinal, - const ExecutableBuildOptions& build_options) + Backend* backend, + ExecutableBuildOptions build_options) : executable_(std::move(executable)), backend_(backend), - build_device_ordinal_(device_ordinal), - build_options_(build_options) {} + build_options_(std::move(build_options)) { + CHECK_GE(build_options_.device_ordinal(), 0) + << "Must have a valid device ordinal that the executable was built for."; +} tensorflow::Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, const Backend& backend) { + const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& computation_layout = executable_->module_config().entry_computation_layout(); @@ -103,14 +76,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( } } - if (options.stream() != nullptr) { - if (!options.stream()->ok()) { + if (run_options.stream() != nullptr) { + if (!run_options.stream()->ok()) { return InvalidArgument("stream is uninitialized or in an error state"); } // Check stream matches service platform. const se::Platform* stream_platform = - options.stream()->parent()->platform(); + run_options.stream()->parent()->platform(); if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", @@ -120,7 +93,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( // Cannot specify device_ordinal with a stream. The stream determines these // values. - if (options.device_ordinal() != -1) { + if (run_options.device_ordinal() != -1) { return InvalidArgument( "cannot set both device ordinal and stream options in " "ExecutableRunOptions; the stream determines the device ordinal"); @@ -129,34 +102,34 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( // Verify that the device the executable was built for is equivalent to the // device it will run on. - int run_device_ordinal = options.device_ordinal() == -1 + int run_device_ordinal = run_options.device_ordinal() == -1 ? backend_->default_device_ordinal() - : options.device_ordinal(); - TF_ASSIGN_OR_RETURN( - bool devices_equivalent, - backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_)); + : run_options.device_ordinal(); + TF_ASSIGN_OR_RETURN(bool devices_equivalent, + backend_->devices_equivalent( + run_device_ordinal, build_options_.device_ordinal())); if (!devices_equivalent) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor, backend_->stream_executor(run_device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor, - backend_->stream_executor(build_device_ordinal_)); + backend_->stream_executor(build_device_ordinal())); return InvalidArgument( "executable is built for device %s of type \"%s\"; cannot run it on " "device %s of type \"%s\"", - backend_->device_name(build_device_ordinal_).c_str(), + backend_->device_name(build_device_ordinal()).c_str(), build_executor->GetDeviceDescription().name().c_str(), backend_->device_name(run_device_ordinal).c_str(), run_executor->GetDeviceDescription().name().c_str()); } - if (!options.allocator()) { + if (!run_options.allocator()) { return InvalidArgument("an allocator must be provided to ExecuteLocally"); } - if (options.allocator()->platform() != backend.platform()) { + if (run_options.allocator()->platform() != backend.platform()) { return InvalidArgument( "allocator platform (%s) does not match service platform (%s)", - options.allocator()->platform()->Name().c_str(), + run_options.allocator()->platform()->Name().c_str(), backend.platform()->Name().c_str()); } @@ -165,23 +138,22 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( StatusOr> LocalExecutable::Run( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options) { - TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); - - ExecutableRunOptions actual_options = options; + ExecutableRunOptions run_options) { + TF_RETURN_IF_ERROR( + ValidateExecutionOptions(arguments, run_options, *backend_)); Backend::StreamPtr stream; - if (options.stream() == nullptr) { + if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of // `actual_options` (otherwise we will end up using a returned stream in // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if" // scope. TF_ASSIGN_OR_RETURN( - stream, BorrowStreamForDevice(options.device_ordinal(), backend_)); - actual_options.set_stream(stream.get()); + stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_)); + run_options.set_stream(stream.get()); } - if (options.allocator() == nullptr) { - actual_options.set_allocator(backend_->memory_allocator()); + if (run_options.allocator() == nullptr) { + run_options.set_allocator(backend_->memory_allocator()); } // For local client execution on CPU backends: @@ -190,7 +162,7 @@ StatusOr> LocalExecutable::Run( // *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). ServiceExecutableRunOptions service_options( - actual_options, backend_->StreamBorrower(), + run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); if (executable_->dumping()) { @@ -199,9 +171,8 @@ StatusOr> LocalExecutable::Run( TF_ASSIGN_OR_RETURN( std::unique_ptr result, executable_->ExecuteOnStreamWrapper( - &service_options, options.execution_profile(), arguments)); - return ScopedShapedBuffer::MakeScoped(result.get(), - actual_options.allocator()); + &service_options, run_options.execution_profile(), arguments)); + return ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator()); } StatusOr> LocalExecutable::ExecuteAndDump( @@ -277,17 +248,19 @@ StatusOr> LocalClient::Compile( const Computation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options) { - int device_ordinal = options.device_ordinal() == -1 - ? default_device_ordinal() - : options.device_ordinal(); + ExecutableBuildOptions updated_options = options; + if (options.device_ordinal() == -1) { + updated_options.set_device_ordinal(default_device_ordinal()); + VLOG(3) << "Set device ordinal to default value of: " + << updated_options.device_ordinal(); + } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, local_service_->CompileExecutable(computation.handle(), argument_layouts, - options.result_layout(), device_ordinal, - options.device_allocator())); + updated_options)); return WrapUnique(new LocalExecutable(std::move(executable), local_service_->mutable_backend(), - device_ordinal, options)); + updated_options)); } StatusOr> diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 843ad7aa85d07e6041b1adb4a1bbc7566838f0e8..b52a30f5a0b92e0094e6b0de3241c10a5a909cad 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,51 +34,13 @@ limitations under the License. namespace xla { -// Class containing options for building an LocalExecutable with -// LocalClient::Compile. -class ExecutableBuildOptions { - public: - // If set, this is the device to build the computation for. Valid - // device_ordinal values are: 0 to # of devices - 1. These values are - // identical to the device ordinal values used by StreamExecutor. The built - // executable will be executable on any device equivalent to the specified - // device as determined by Backend::devices_equivalent(). A value of -1 - // indicates this option has not been set. - ExecutableBuildOptions& set_device_ordinal(int device_ordinal); - int device_ordinal() const; - - // If set, this specifies the layout of the result of the computation. If not - // set, the service will chose the layout of the result. A Shape is used to - // store the layout to accommodate tuple result shapes. A value of nullptr - // indicates the option has not been set. - ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); - const Shape* result_layout() const; - - // If set, this specifies an allocator that can be used to allocate temporary - // space on the device during compilation. For example, the compiler might - // want to run various algorithms on the device and pick the fastest one -- it - // might allocate buffers for use by these algorithms using this allocator. - // - // This does not need to be the same as the DeviceMemoryAllocator passed when - // running the executable. - ExecutableBuildOptions& set_device_allocator( - DeviceMemoryAllocator* allocator); - DeviceMemoryAllocator* device_allocator() const; - - private: - int device_ordinal_ = -1; - Shape result_layout_; - bool result_layout_set_ = false; - DeviceMemoryAllocator* device_allocator_ = nullptr; -}; - class LocalExecutable { public: // Run the compiled computation with the given arguments and options and // return the result. StatusOr> Run( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options); + ExecutableRunOptions run_options); // Return the layout (contained in a shape) of the result produced by the // computation. @@ -100,8 +63,7 @@ class LocalExecutable { // Constructor invoked by LocalClient. LocalExecutable(std::unique_ptr executable, Backend* backend, - int device_ordinal, - const ExecutableBuildOptions& build_options); + ExecutableBuildOptions build_options); // Validates that the given arguments and options satisfy various constraints // of the computation. @@ -129,19 +91,19 @@ class LocalExecutable { StatusOr> LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer); + // The ordinal of the device which this executable was compiled for. The + // executable can run on all equivalent devices (as determined by + // Backend::devices_equivalent). + int build_device_ordinal() const { return build_options_.device_ordinal(); } + // Compiled computation. std::unique_ptr executable_; // Execution backend. - Backend* backend_; - - // The ordinal of the device which this executable was compiled for. The - // executable can run on all equivalent devices (as determined by - // Backend::devices_equivalent). - int build_device_ordinal_; + Backend* backend_ = nullptr; // Options used to build the executable. - const ExecutableBuildOptions& build_options_; + const ExecutableBuildOptions build_options_; }; // An XLA Client specialization for use when the client and service run in diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index fe3a4d2f6df47d9f156529e55198a5f339bc8e3c..c8ed3e3a2b009ddffdfb79a9a6ced8d5e736bee6 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -221,13 +221,19 @@ void AllocateFlags() { flag_values->xla_gpu_disable_multi_streaming(), "If true, multi-streaming in the GPU backend is disabled."), tensorflow::Flag( - "xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(), - "Dump compilation artifacts as proto binary into this directory."), + "xla_dump_optimized_hlo_proto_to", + flag_values->mutable_xla_dump_optimized_hlo_proto_to(), + "Dump Hlo after all hlo passes are executed as proto binary into " + "this directory."), tensorflow::Flag( - "xla_dump_prepass_hlo_proto_to", - flag_values->mutable_xla_dump_prepass_hlo_proto_to(), - "Dump compilation artifacts, before hlo passes are executed, as " - "proto binary into this directory."), + "xla_dump_unoptimized_hlo_proto_to", + flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(), + "Dump HLO before any hlo passes are executed as proto binary into " + "this directory."), + tensorflow::Flag("xla_dump_per_pass_hlo_proto_to", + flag_values->mutable_xla_dump_per_pass_hlo_proto_to(), + "Dump HLO after each pass as an HloProto in binary file " + "format into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index a8ca0e3ea0115d412e96ebacb320cc0dde061dff..e2972f06016ab3555c4fc0cc4616993fe6764b1e 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -49,6 +49,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 37f1eada2bc9f5ef72d99a835a17b4e78a354ae6..3b0d8377395ca2a91fb007b784773e6df9c8d6c0 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -98,15 +98,25 @@ const std::unique_ptr& LocalShapedBuffer::shaped_buffer() return shaped_buffer_; } +static StatusOr> ToBuffer( + LocalClient* client, int device_ordinal, const Literal& arg) { + return client->LiteralToShapedBuffer(arg, device_ordinal, + client->backend().memory_allocator()); +} + /* static */ -LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) { +LocalShapedBuffer* LocalShapedBuffer::FromLiteral( + const Literal& argument, + const tensorflow::gtl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); - std::unique_ptr buf = - client - ->LiteralToShapedBuffer(argument, - /*device_ordinal=*/0, - client->backend().memory_allocator()) - .ConsumeValueOrDie(); + std::unique_ptr buf; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie(); + } else { + buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); + } return new LocalShapedBuffer(std::move(buf)); } @@ -120,7 +130,8 @@ CompiledLocalComputation::CompiledLocalComputation( : executable_(std::move(executable)) {} StatusOr> CompiledLocalComputation::Execute( - const std::vector& arguments) { + const std::vector& arguments, + const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; @@ -133,7 +144,8 @@ StatusOr> CompiledLocalComputation::Execute( GetReplicaCount()); for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule([this, client, replica, &arguments, &results] { + pool.Schedule([this, client, replica, &arguments, &shapes_with_layout, + &results] { StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(replica); if (!device_ordinal_status.ok()) { @@ -144,18 +156,28 @@ StatusOr> CompiledLocalComputation::Execute( VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; + // Transfer arguments in std::vector> scoped_buffers; scoped_buffers.reserve(arguments.size()); - for (const Literal& argument : arguments) { - StatusOr> pushed = - client->LiteralToShapedBuffer( - argument, device_ordinal, - client->backend().memory_allocator()); + for (int i = 0; i < arguments.size(); ++i) { + const Literal& argument = arguments[i]; + const tensorflow::gtl::optional& shape_with_layout = + shapes_with_layout[i]; + + StatusOr> pushed; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, *relaid); + } else { + pushed = ToBuffer(client, device_ordinal, argument); + } if (!pushed.ok()) { results[replica] = pushed.status(); return; } + scoped_buffers.push_back(std::move(pushed).ValueOrDie()); } @@ -233,7 +255,8 @@ LocalComputation::LocalComputation(Computation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( - const std::vector& argument_shapes) { + const std::vector& argument_shapes, + const ExecutableBuildOptions* build_options) { std::vector argument_shape_pointers; argument_shape_pointers.reserve(argument_shapes.size()); for (auto& argument_shape : argument_shapes) { @@ -242,6 +265,9 @@ StatusOr LocalComputation::Compile( LocalClient* client = GetOrCreateLocalClient(); ExecutableBuildOptions options; + if (build_options != nullptr) { + options = *build_options; + } TF_ASSIGN_OR_RETURN( auto local_executable, client->Compile(computation_, argument_shape_pointers, options)); @@ -363,12 +389,6 @@ LocalComputationBuilder::SelectAndScatterWithGeneralPadding( source, init_value, scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Select( - const ComputationDataHandle& pred, const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false) { - return builder_.Select(pred, on_true, on_false); -} - ComputationDataHandle LocalComputationBuilder::Tuple( tensorflow::gtl::ArraySlice elements) { return builder_.Tuple(elements); @@ -384,6 +404,12 @@ ComputationDataHandle LocalComputationBuilder::Dot( return builder_.Dot(lhs, rhs); } +ComputationDataHandle LocalComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + return builder_.DotGeneral(lhs, rhs, dimension_numbers); +} + ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, @@ -467,6 +493,17 @@ ComputationDataHandle LocalComputationBuilder::While( return builder_.While(condition.computation(), body.computation(), init); } +ComputationDataHandle LocalComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const LocalComputation& true_computation, + const ComputationDataHandle& false_operand, + const LocalComputation& false_computation) { + return builder_.Conditional(predicate, true_operand, + true_computation.computation(), false_operand, + false_computation.computation()); +} + #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ return builder_.method_name args; \ @@ -483,6 +520,15 @@ ComputationDataHandle LocalComputationBuilder::While( tensorflow::gtl::ArraySlice broadcast_dimensions), \ (lhs, rhs, broadcast_dimensions)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + const ComputationDataHandle& ehs), \ + (lhs, rhs, ehs)) + +_FORWARD_TRIOP(Select) +_FORWARD_TRIOP(Clamp) _FORWARD_BINOP(Eq) _FORWARD_BINOP(Ne) _FORWARD_BINOP(Ge) @@ -503,6 +549,7 @@ _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) +_FORWARD_UNOP(Round) _FORWARD_UNOP(Log) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) @@ -519,6 +566,7 @@ _FORWARD_UNOP(Sort) #undef _FORWARD #undef _FORWARD_UNOP #undef _FORWARD_BINOP +#undef _FORWARD_TRIOP void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { delete local_shaped_buffer; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e5503cd52fa60eff30eea38c83aafe0f0ff1efc8..4c6a504f4cd83533185cdadf60ae2c53a0d5e911 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -58,7 +59,9 @@ StatusOr > TransferFromOutfeedLocalReplica( // client. class LocalShapedBuffer { public: - static LocalShapedBuffer* FromLiteral(const Literal& argument); + static LocalShapedBuffer* FromLiteral( + const Literal& argument, + const tensorflow::gtl::optional& shape_with_layout); LocalShapedBuffer(std::unique_ptr shaped_buffer); const std::unique_ptr& shaped_buffer() const; std::unique_ptr ToLiteral() const; @@ -76,8 +79,15 @@ class LocalShapedBuffer { class CompiledLocalComputation { public: CompiledLocalComputation(std::unique_ptr executable); + + // Execute the computation with the given argument literals, and + // with optionally-specified argument layouts. The literals will be + // re-laid out according to the corresponding elements of + // shapes_with_layout. StatusOr > Execute( - const std::vector& arguments); + const std::vector& arguments, + const std::vector >& shapes_with_layout); + LocalShapedBuffer* ExecuteWithShapedBuffers( tensorflow::gtl::ArraySlice argument_handles); @@ -93,7 +103,8 @@ class LocalComputation { public: LocalComputation(Computation computation); StatusOr Compile( - const std::vector& argument_shapes); + const std::vector& argument_shapes, + const ExecutableBuildOptions* build_options); const Computation& computation() const; private: @@ -172,10 +183,6 @@ class LocalComputationBuilder { const ComputationDataHandle& source, const ComputationDataHandle& init_value, const LocalComputation& scatter); - ComputationDataHandle Select(const ComputationDataHandle& pred, - const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false); - ComputationDataHandle Tuple( tensorflow::gtl::ArraySlice elements); @@ -185,6 +192,10 @@ class LocalComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + ComputationDataHandle ConvGeneralDilated( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, @@ -239,6 +250,12 @@ class LocalComputationBuilder { const LocalComputation& body, const ComputationDataHandle& init); + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const LocalComputation& true_computation, + const ComputationDataHandle& false_operand, + const LocalComputation& false_computation); + #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; @@ -252,6 +269,14 @@ class LocalComputationBuilder { (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + const ComputationDataHandle& ehs)) + + _FORWARD_TRIOP(Select) + _FORWARD_TRIOP(Clamp) _FORWARD_BINOP(Eq) _FORWARD_BINOP(Ne) _FORWARD_BINOP(Ge) @@ -272,6 +297,7 @@ class LocalComputationBuilder { _FORWARD_UNOP(Exp) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) + _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) @@ -288,6 +314,7 @@ class LocalComputationBuilder { #undef _FORWARD #undef _FORWARD_UNOP #undef _FORWARD_BINOP +#undef _FORWARD_TRIOP private: ComputationBuilder builder_; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 31789259609714e7d20247eec072e05a181715e6..114754bde4033a13e217bd6552ebffbde7c3503b 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -27,12 +27,14 @@ limitations under the License. // ArraySlice <- sequence of int // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray -// Shape <-> pair holding (dtype, dimensions) -// std::vector <- sequence of shape information pairs +// Shape -> pair holding (dtype, dimensions) +// <- object duck-typed as xla_client.Shape +// std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int // ArraySlice> <- sequence of int pairs // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto +// DotDimensionNumbers proto <- corresponding Python proto // // Arrows indicate whether a conversion only ever occurs in one // direction, or whether it is maintained bidirectionally. @@ -55,7 +57,7 @@ limitations under the License. // translates to a tuple-shaped XLA Literal, whose component subshapes // are a 2x3 F32-shaped literal followed by two tuple-shaped literals. // -// The Python objects corresponding to C++ Shapes have the type: +// Shapes output by C++ become Python objects with the type: // // T = (dtype, S) // S = DIMENSIONS | TUPLE_SHAPES @@ -176,6 +178,16 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr< std::unique_ptr > { + if ($1.ok()) { + std::unique_ptr value = $1.ConsumeValueOrDie(); + $result = numpy::PyObjectFromXlaLiteral(*value); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -343,15 +355,31 @@ tensorflow::ImportNumpy(); // Shape %typemap(in) const Shape& (Shape temp) { - Status shape_status = numpy::CheckPyShapeInfo($input); - if (!shape_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str()); + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); return NULL; } - temp = numpy::XlaShapeFromPyShapeInfo($input); + temp = std::move(statusor).ValueOrDie(); $1 = &temp; } +%typemap(in) const tensorflow::gtl::optional& ( + tensorflow::gtl::optional temp) { + if ($input == Py_None) { + temp = tensorflow::gtl::nullopt; + $1 = &temp; + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; + } +} + %typemap(out) std::unique_ptr { $result = numpy::PyShapeInfoFromXlaShape(*$1); } @@ -364,14 +392,37 @@ tensorflow::ImportNumpy(); const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - Status shape_status = numpy::CheckPyShapeInfo(o); - if (!shape_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str()); - Py_DECREF(o); + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); return NULL; } - temps.push_back(numpy::XlaShapeFromPyShapeInfo(o)); - Py_DECREF(o); + temps.push_back(statusor.ConsumeValueOrDie()); + } + $1 = &temps; +} + +%typemap(in) const std::vector >& ( + std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (o == Py_None) { + temps.push_back(tensorflow::gtl::nullopt); + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } } $1 = &temps; } @@ -461,6 +512,135 @@ tensorflow::ImportNumpy(); $1 = temps; } +// DotDimensionNumbers + +%typemap(in) const DotDimensionNumbers& + (DotDimensionNumbers dimension_numbers) { + int length; + + /* lhs_contracting_dimensions */ + PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( + $input, "lhs_contracting_dimensions"); + if (!lhs_contracting_dimensions) { + return NULL; + } + + length = PySequence_Size(lhs_contracting_dimensions); + if (length == -1) { + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); + if (!item) { + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + dimension_numbers.add_lhs_contracting_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(lhs_contracting_dimensions); + + /* rhs_contracting_dimensions */ + PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( + $input, "rhs_contracting_dimensions"); + if (!lhs_contracting_dimensions) { + return NULL; + } + + length = PySequence_Size(rhs_contracting_dimensions); + if (length == -1) { + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); + if (!item) { + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + dimension_numbers.add_rhs_contracting_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(rhs_contracting_dimensions); + + /* lhs_batch_dimensions */ + PyObject* lhs_batch_dimensions = PyObject_GetAttrString( + $input, "lhs_batch_dimensions"); + if (!lhs_batch_dimensions) { + return NULL; + } + + length = PySequence_Size(lhs_batch_dimensions); + if (length == -1) { + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); + if (!item) { + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + dimension_numbers.add_lhs_batch_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(lhs_batch_dimensions); + + /* rhs_batch_dimensions */ + PyObject* rhs_batch_dimensions = PyObject_GetAttrString( + $input, "rhs_batch_dimensions"); + if (!rhs_batch_dimensions) { + return NULL; + } + + length = PySequence_Size(rhs_batch_dimensions); + if (length == -1) { + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); + if (!item) { + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + dimension_numbers.add_rhs_batch_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(rhs_batch_dimensions); + + $1 = &dimension_numbers; +} + // PaddingConfig %typemap(in) const PaddingConfig& @@ -623,6 +803,30 @@ tensorflow::ImportNumpy(); $1 = &dimension_numbers; } +// ExecutableBuildOptions + +%typemap(in) const ExecutableBuildOptions* + (ExecutableBuildOptions build_options) { + if ($input == Py_None) { + $1 = NULL; + } else { + PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph"); + if (!o) { + return NULL; + } + if (o != Py_None) { + if (!PyString_Check(o)) { + PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None."); + return NULL; + } + build_options.set_generate_hlo_graph(PyString_AsString(o)); + } + Py_DECREF(o); + + $1 = &build_options; + } +} + %ignoreall %unignore xla; %unignore xla::swig; @@ -667,6 +871,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Call; %unignore xla::swig::LocalComputationBuilder::Transpose; %unignore xla::swig::LocalComputationBuilder::Rev; +%unignore xla::swig::LocalComputationBuilder::Clamp; %unignore xla::swig::LocalComputationBuilder::Map; %unignore xla::swig::LocalComputationBuilder::Reduce; %unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; @@ -674,6 +879,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::RngUniform; %unignore xla::swig::LocalComputationBuilder::RngBernoulli; %unignore xla::swig::LocalComputationBuilder::While; +%unignore xla::swig::LocalComputationBuilder::Conditional; %unignore xla::swig::LocalComputationBuilder::Eq; %unignore xla::swig::LocalComputationBuilder::Ne; %unignore xla::swig::LocalComputationBuilder::Ge; @@ -681,6 +887,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Lt; %unignore xla::swig::LocalComputationBuilder::Le; %unignore xla::swig::LocalComputationBuilder::Dot; +%unignore xla::swig::LocalComputationBuilder::DotGeneral; %unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; %unignore xla::swig::LocalComputationBuilder::Add; %unignore xla::swig::LocalComputationBuilder::Sub; @@ -696,6 +903,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Exp; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; +%unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 5c722623e318ece9eca6bdc8750195ce5fd5defb..3d87480728aab1d4ebbc71c6c7504d37cae5edaf 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -176,85 +176,107 @@ static string PyObjectCppRepr(PyObject* o) { return ExtractStringAndDecref(r); } -Status CheckPyShapeInfo(PyObject* o) { +StatusOr XlaShapeFromPyShape(PyObject* o) { auto error = [o](const string& prefix) { return InvalidArgument("%s; got %s", prefix.c_str(), PyObjectCppRepr(o).c_str()); }; - // The object is a tuple (a pair) - if (!PyTuple_Check(o)) { - return error("Shape record must be a tuple"); - } - if (PyTuple_Size(o) != 2) { - return error("Shape record tuple must be of length 2"); - } - // It has a first element, which is a numpy dtype object - PyObject* first = PyTuple_GetItem(o, 0); - if (first == nullptr) { - return error("Tuple has no item 0 (shape dtype)"); - } - if (first->ob_type != &PyArrayDescr_Type) { - return error( - "Shape record does not have a numpy dtype as its first element"); - } - const int np_type = NumpyTypenum(first); - if (!NumpyTypeIsValid(np_type)) { - return error("Shape record has an invalid integer dtype"); - } + auto get_attr = [o, &error](const string& field) -> StatusOr { + PyObject* result = + PyObject_GetAttrString(o, const_cast(field.c_str())); + if (result == nullptr) { + return error(tensorflow::strings::StrCat( + "Failed to get attribute of Shape object:", field)); + } + return result; + }; - // It has a second element, which is a tuple, either of shape - // records or of Python ints - PyObject* second = PyTuple_GetItem(o, 1); - if (!second) { - return error("Tuple has no item 0 (shape dimensions)"); - } - if (!PyTuple_Check(second)) { - return error("Shape record does not have a tuple as its second element"); - } - const int length = PyTuple_Size(second); - const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); - for (int i = 0; i < length; i++) { - PyObject* dimension = PyTuple_GetItem(second, i); - if (element_type == TUPLE) { - VLOG(3) << "element_type is tuple, checking member: " << i; - Status result = CheckPyShapeInfo(dimension); - if (!result.ok()) { - return AddStatus( - result, tensorflow::strings::StrCat("Validating tuple member ", i, - " of ", PyObjectCppRepr(o))); - } - } else if (!CheckPyIntOrLong(dimension)) { - return error("Non-tuple shape record has a non-integer dimension"); + auto call_method = [o, &error](const string& method) -> StatusOr { + PyObject* result = + PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); + if (result == nullptr) { + return error(tensorflow::strings::StrCat( + "Failed to call method of shape object:", method)); } - } + return result; + }; - return Status::OK(); -} + PyObject* np_type; + TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype")); + if (np_type->ob_type != &PyArrayDescr_Type) { + return error("Shape attribute np_dtype is not an integer numpy dtype"); + } + if (!NumpyTypeIsValid(NumpyTypenum(np_type))) { + return error("Shape attribute np_dtype is not a valid integer numpy dtype"); + } + const PrimitiveType element_type = + NumpyTypeToPrimitiveType(NumpyTypenum(np_type)); + Py_DECREF(np_type); -// Precondition: CheckPyShapeInfo(o) -Shape XlaShapeFromPyShapeInfo(PyObject* o) { - const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0)); - const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); - PyObject* py_dimensions = PyTuple_GetItem(o, 1); - const int length = PyTuple_Size(py_dimensions); if (element_type == TUPLE) { + PyObject* py_subshapes; + TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes")); + if (!PyTuple_Check(py_subshapes)) { + return error( + "Return value of Shape method tuple_shapes() is not a tuple"); + } + const int length = PyTuple_Size(py_subshapes); std::vector subshapes; subshapes.reserve(length); for (int i = 0; i < length; i++) { - subshapes.push_back( - XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i))); + TF_ASSIGN_OR_RETURN( + const Shape& subshape, + XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i))); + subshapes.push_back(subshape); } + Py_DECREF(py_subshapes); return ShapeUtil::MakeTupleShape(subshapes); } else { + PyObject* py_dimensions; + PyObject* py_minor_to_major; + TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions")); + TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major")); + if (!PyTuple_Check(py_dimensions)) { + return error("Return value of Shape method dimensions() is not a tuple"); + } + if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) { + return error( + "Return value of Shape method minor_to_major() is neither a tuple " + "nor None"); + } + const int length = PyTuple_Size(py_dimensions); + if (py_minor_to_major != Py_None && + length != PyTuple_Size(py_minor_to_major)) { + return error( + "Shape methods dimensions() and minor_to_major() return " + "different-length tuples"); + } std::vector dimensions(length); + std::vector minor_to_major(length); for (int i = 0; i < length; i++) { dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i)); - if (dimensions[i] == -1) { - CHECK(!PyErr_Occurred()); + if (dimensions[i] == -1 && PyErr_Occurred()) { + return error("Dimension is not an int"); } + + if (py_minor_to_major != Py_None) { + minor_to_major[i] = + PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i)); + if (minor_to_major[i] == -1 && PyErr_Occurred()) { + return error("Minor-to-major value is not an int"); + } + } + } + bool with_layout = py_minor_to_major != Py_None; + Py_DECREF(py_dimensions); + Py_DECREF(py_minor_to_major); + if (with_layout) { + return ShapeUtil::MakeShapeWithLayout(element_type, dimensions, + minor_to_major); + } else { + return ShapeUtil::MakeShape(element_type, dimensions); } - return ShapeUtil::MakeShape(element_type, dimensions); } } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 6ff1c34cfc5e0323a6729bdfd5572239f4966211..adfcc3b8588dce01718bb19dea936bace483be4d 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -56,15 +56,11 @@ bool NumpyTypeIsValid(int np_type); // The return value is a new reference. PyObject* PyShapeInfoFromXlaShape(const Shape& shape); -// Returns the outcome of a best-effort check that the Python object -// is a pair of the form (numpy dtype, dimensions), as produced by -// PyShapeInfoFromXlaShape. -Status CheckPyShapeInfo(PyObject* o); - -// Performs the inverse conversion to that of PyShapeInfoFromXlaShape. +// Converts a Python object with a method interface mathing that of +// xla_client.Shape into an XLA Shape object. // // The return value is a new reference. -Shape XlaShapeFromPyShapeInfo(PyObject* o); +StatusOr XlaShapeFromPyShape(PyObject* o); // Converts a PyObject that represents operation metadata into protocol buffer // form. diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 66ace613a0c66c9577deeb9daa6f674ede5a8865..f8cee5d5665cf95b19e037658b88ffddd5efa511 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -89,6 +89,7 @@ _UNARY_OPS = [ 'Abs', 'Exp', 'Floor', + 'Round', 'Ceil', 'Log', 'Sign', @@ -155,9 +156,14 @@ class LocalBuffer(object): self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_py(npval): + def from_py(npval, layout_fn=None): npval = require_numpy_array_layout(npval) - return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval)) + if layout_fn: + shape = Shape.from_numpy(npval) + shape = shape.map_leaves(layout_fn) + else: + shape = None + return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape)) def to_py(self): return self.c_local_shaped_buffer.ToLiteral() @@ -182,13 +188,17 @@ class Shape(object): represents an XLA tuple. """ - def __init__(self, np_dtype, dimensions): + def __init__(self, np_dtype, dimensions, minor_to_major=None): + assert isinstance(dimensions, tuple) self.np_dtype = np_dtype self._dimensions = dimensions + self._minor_to_major = minor_to_major + self._check_minor_to_major() def __repr__(self): - return 'xla_client.Shape(np_dtype={!r}, dimensions={!r})'.format( - self.np_dtype, self._dimensions) + return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' + 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, + self._minor_to_major) def element_type(self): return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] @@ -201,11 +211,49 @@ class Shape(object): raise ValueError('Tuple shape has no dimensions') return self._dimensions + def minor_to_major(self): + return self._minor_to_major + def tuple_shapes(self): if not self.is_tuple(): raise ValueError('Shape is not a tuple shape') return self._dimensions + def rank(self): + return len(self.dimensions()) + + def map_leaves(self, f): + """Map f over each leaf-level array subshape. + + Args: + f: The function to apply. Whenever f returns None, the identity is + applied instead. + + Returns: + A new Shape with the mapped leaves. + """ + if self.is_tuple(): + children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) + return Shape(np.dtype('O'), children) + else: + mapped = f(self) + return self if mapped is None else mapped + + def _check_minor_to_major(self): + mtm = self._minor_to_major + if self.is_tuple(): + assert mtm is None, self + if mtm is not None: + assert self.rank() == len(mtm), self + assert sorted(mtm) == range(len(mtm)), self + + def update_minor_to_major(self, minor_to_major): + if not isinstance(minor_to_major, tuple): + raise TypeError('minor_to_major must be a tuple') + updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major) + updated._check_minor_to_major() # pylint: disable=protected-access + return updated + @staticmethod def from_numpy(npval): @@ -222,23 +270,10 @@ def _wrap_shape(shape_info): dtype, dims = shape_info element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] if element_type == xla_data_pb2.TUPLE: - dims = [_wrap_shape(subshape_info) for subshape_info in dims] + dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims) return Shape(dtype, dims) -def _unwrap_shape(shape): - if shape.is_tuple(): - components = tuple( - _unwrap_shape(subshape) for subshape in shape.tuple_shapes()) - else: - components = shape.dimensions() - return (shape.np_dtype, components) - - -def _unwrap_shapes(shapes): - return [_unwrap_shape(shape) for shape in shapes] - - def _wrap_data_handle(handle): cdh = xla_data_pb2.ComputationDataHandle() cdh.handle = handle @@ -260,6 +295,17 @@ def require_numpy_array_layout(value): return np.require(value, requirements=['C', 'A']) +class CompileOptions(object): + """Python object for XLA compile options. + + These options can be passed to the 'compile' step when using a local XLA + client. + """ + + def __init__(self): + self.generate_hlo_graph = None + + def transfer_to_infeed(value, replica_number=None): """Transfers the given value into the XLA infeed queue. @@ -291,8 +337,7 @@ def transfer_from_outfeed(shape, replica_number=None): Returns: The literal value that is produced from the outfeed queue. """ - return c_api.TransferFromOutfeedLocalReplica( - _unwrap_shape(shape), replica_number or 0) + return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) class LocalComputation(object): @@ -313,22 +358,39 @@ class LocalComputation(object): else: self._delete = c_api.DeleteLocalComputation - def Compile(self, argument_shapes=()): + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + if layout_fn: + argument_shapes = [ + shape.map_leaves(layout_fn) for shape in argument_shapes + ] return LocalComputation( - self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)), + self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) - def CompileWithExampleArguments(self, arguments=()): + def CompileWithExampleArguments(self, + arguments=(), + compile_options=None, + layout_fn=None): return self.Compile( - argument_shapes=[Shape.from_numpy(arg) for arg in arguments]) + argument_shapes=[Shape.from_numpy(arg) for arg in arguments], + compile_options=compile_options, + layout_fn=layout_fn) - def Execute(self, arguments=()): + def Execute(self, arguments=(), layout_fn=None): + """Execute with Python values as arguments and return value.""" if not self.is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') + argument_shapes = [Shape.from_numpy(arg) for arg in arguments] + if layout_fn: + argument_shapes = [ + shape.map_leaves(layout_fn) for shape in argument_shapes + ] + else: + argument_shapes = [None for shape in argument_shapes] arguments = tuple(map(require_numpy_array_layout, arguments)) - return self.c_local_computation.Execute(arguments) + return self.c_local_computation.Execute(arguments, argument_shapes) def ExecuteWithLocalBuffers(self, arguments=()): """Execute with LocalBuffer arguments and return value.""" @@ -384,7 +446,7 @@ class ComputationBuilder(object): Returns: A ComputationDataHandle message. """ - return _wrap_data_handle(self._client.Infeed(_unwrap_shape(shape))) + return _wrap_data_handle(self._client.Infeed(shape)) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -393,7 +455,7 @@ class ComputationBuilder(object): outfeed queue for subsequent dequeue via the client API. """ self._client.Outfeed( - _unwrap_data_handle(operand), _unwrap_shape(self.GetShape(operand)), + _unwrap_data_handle(operand), self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): @@ -484,8 +546,7 @@ class ComputationBuilder(object): parameter_num = next(self._parameter_numbering) return _wrap_data_handle( - self._client.Parameter( - parameter_num, _unwrap_shape(shape), name.encode('utf8'))) + self._client.Parameter(parameter_num, shape, name.encode('utf8'))) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -606,6 +667,13 @@ class ComputationBuilder(object): return _wrap_data_handle( self._client.Rev(_unwrap_data_handle(operand), dimensions)) + def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin + """Clamp op.""" + return _wrap_data_handle( + self._client.Clamp(_unwrap_data_handle(min), + _unwrap_data_handle(operand), + _unwrap_data_handle(max))) + def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. @@ -825,8 +893,7 @@ class ComputationBuilder(object): shape = Shape(self.GetShape(mu).np_dtype, dims) return _wrap_data_handle( self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), - _unwrap_shape(shape))) + _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. @@ -846,8 +913,7 @@ class ComputationBuilder(object): shape = Shape(self.GetShape(a).np_dtype, dims) return _wrap_data_handle( self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), - _unwrap_shape(shape))) + _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -855,7 +921,7 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: an ComputationDataHandle for the initial parameter, which has type T + init: a ComputationDataHandle for the initial parameter, which has type T Returns: a ComputationDataHandle representing the While operation. """ @@ -864,11 +930,58 @@ class ComputationBuilder(object): body.c_local_computation, _unwrap_data_handle(init))) + def Conditional(self, pred, true_operand, true_computation, false_operand, + false_computation): + """Enqueues a Conditional operation onto the computation. + + Args: + predicate: a ComputationDataHandle to test, which has scalar type PRED + true_operand: a ComputationDataHandle of type T_0 + true_computation: a Computation to apply to true_operand, type T_0 -> S + false_operand: a ComputationDatahandle of type T_1 + false_computation: a Computation to apply to false_operand, type T_1 -> S + + Returns: a ComputationDataHandle representing the Conditional operation. + """ + return _wrap_data_handle( + self._client.Conditional( + _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), + true_computation.c_local_computation, + _unwrap_data_handle(false_operand), + false_computation.c_local_computation)) + def Dot(self, lhs, rhs): - """Matrix multiplication between lhs and rhs.""" + """Enqueues a dot operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. + rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + + Returns: a ComputationDataHandle representing the Dot operation. + """ return _wrap_data_handle( self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + def DotGeneral(self, lhs, rhs, dimension_numbers): + """Enqueues a general dot operation onto the computation. + + Args: + lhs: ComputationDataHandle for the left-hand-side array. + rhs: ComputationDataHandle for the right-hand-side array. + dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested + tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of + integers representing the dimensions to treat as contracting dimensions + and batch dimensions on each input operand. + + Returns: a ComputationDataHandle representing the DotGeneral operation. + """ + if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): + dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) + return _wrap_data_handle( + self._client.DotGeneral( + _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), + dimension_numbers)) + def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. @@ -979,7 +1092,7 @@ def initialize_replica_count(replica_count): Args: replica_count: number of replicas that are desired for set up during XLA - initalization. + initialization. Raises: A runtime exception if the XLA service has already been initialized. @@ -1005,3 +1118,13 @@ def GetPaddingConfigFromTriples(triples): dimension.edge_padding_high = hi dimension.interior_padding = interior return padding_config + + +def GetDotDimensionsFromLists(dimension_numbers): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = xla_data_pb2.DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c0413b9bbc3b7f8b63e4cf7a8f24980322cffc47..65720c6ef9ec1cd7a816bcf719960fa803dd45a1 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -444,6 +444,30 @@ class SingleOpTest(LocalComputationTest): c.Dot(c.Constant(lhs), c.Constant(rhs)) self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + def testDotGeneral(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = (([2], [1]), ([0], [0])) + c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + + def testDotGeneralWithDotDimensionNumbersProto(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + + dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers() + dimension_numbers.lhs_contracting_dimensions.append(2) + dimension_numbers.rhs_contracting_dimensions.append(1) + dimension_numbers.lhs_batch_dimensions.append(0) + dimension_numbers.rhs_batch_dimensions.append(0) + + c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + def testConvF32Same(self): c = self._NewComputation() a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") @@ -496,6 +520,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testRound(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Round(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.round(arr)) + def testLog(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -699,6 +729,23 @@ class SingleOpTest(LocalComputationTest): self._ExecuteAndCompareExact( c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) + def testClampF32(self): + c = self._NewComputation() + c.Clamp( + c.Constant(NumpyArrayF32(-1)), + c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), + c.Constant(NumpyArrayF32(2))) + self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) + + # TODO(b/72689392): re-enable when bug S32 resolved + def DISABLED_testClampS32(self): + c = self._NewComputation() + c.Clamp( + c.Constant(NumpyArrayS32(-1)), + c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), + c.Constant(NumpyArrayS32(2))) + self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2]) + def testSelect(self): c = self._NewComputation() c.Select( @@ -1172,6 +1219,28 @@ class EmbeddedComputationsTest(LocalComputationTest): c.While(cond, body, init) self._ExecuteAndCompareClose(c, expected=16.) + def testConditionalTrue(self): + c = self._NewComputation() + pred = c.ConstantPredScalar(True) + true_operand = c.ConstantF32Scalar(3.) + true_computation = self._CreateMulF32By2Computation() + false_operand = c.ConstantF32Scalar(2.) + false_computation = self._CreateConstantF32Computation() + c.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=6.) + + def testConditionalFalse(self): + c = self._NewComputation() + pred = c.ConstantPredScalar(False) + true_operand = c.ConstantF32Scalar(3.) + true_computation = self._CreateMulF32By2Computation() + false_operand = c.ConstantF32Scalar(2.) + false_computation = self._CreateConstantF32Computation() + c.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=1.) + def testInfeedS32Values(self): to_infeed = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a9eb6b53e026865560e103a1027f9ea1eb9a3965..0f2d0a9e96e20007aa24a22832bdca4f0add372d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -509,6 +509,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -1884,6 +1885,7 @@ cc_library( ":hlo", ":hlo_graph_dumper", ":hlo_pass", + ":hlo_proto_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ba82e822b216528c28536181059bc2417048de01..fb857559f972a220a19b108baa4c441e09b90e1f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1618,9 +1618,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); } + // A Transpose feeding a reduce can simply permute the reduction dimensions - // field. - if (arg->opcode() == HloOpcode::kTranspose) { + // field if the output of the reduce is a vector or scalar. Higher ranked + // result may require a transpose of the output. + if (ShapeUtil::Rank(reduce->shape()) <= 1 && + arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; for (auto dim : dimensions) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index d5594dc07c8f525a431e6a0f0f6865db6d094774..774b11478c6d2faf0eb5db29df3cfd3cc1e98d5b 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -997,14 +997,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( auto color = single_colored_set.first; VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); + HeapSimulator::Options options; + options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( MakeUnique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), - assignment->buffer_size_, - &single_colored_set.second)); + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1024,14 +1025,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( auto color = single_colored_set.first; VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); + HeapSimulator::Options options; + options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( MakeUnique(alignment)), *computation, *instruction_sequence, assignment->points_to_analysis(), - assignment->buffer_size_, - &single_colored_set.second)); + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 3fdb3d5ca64e2dde9bb6f4ea5c4c50187be7a787..d13a97bcc9a84afb22556389b4cdcd985f58d445 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -519,8 +519,8 @@ StatusOr> CpuCompiler::RunBackend( // ownership is std::moved. const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); if (options::CpuParallelBackendRequested(module->config())) { VLOG(1) << "Using parallel cpu backend"; @@ -540,10 +540,10 @@ StatusOr> CpuCompiler::RunBackend( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!xla_dump_hlo_proto_to.empty()) { + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -649,10 +649,10 @@ StatusOr> CpuCompiler::RunBackend( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!xla_dump_hlo_proto_to.empty()) { + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } // Each computation is a single function. Emit all embedded computations @@ -828,12 +828,12 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); - if (!xla_dump_hlo_proto_to.empty()) { + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 71e81331897a8bb82438dd5160d2964cb88fd31f..0b2d3d47463b745049807e9afa55360434ad522b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -479,7 +479,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { Status IrEmitter::HandleSort(HloInstruction* sort) { // TODO(b/26783907): Implement sort on CPU. - return Unimplemented("Sort is not supported on CPU (b/26783907)."); + return Unimplemented("Sort is not implemented on CPU."); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { @@ -522,7 +522,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for reduce-window not implemented on CPU. See b/31410564."); + "Dilation for ReduceWindow is not implemented on CPU."); } // The called computation should have been emitted previously. @@ -625,8 +625,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // TODO(b/31410564): Implement dilation for select-and-scatter. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for select-and-scatter not implemented on CPU. " - "See b/31410564."); + "Dilation for SelectAndScatter is not implemented on CPU. "); } // The select and scatter computations should have been emitted previously. @@ -1196,8 +1195,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { } // TODO(b/33011107): Support cross replica sum on CPU. - return Unimplemented( - "Cross replica sum is not implemented on CPU. See b/33011107."); + return Unimplemented("CrossReplicaSum is not implemented on CPU."); } // Fills up the free variables in 'index_with_free_var' with values from @@ -1811,12 +1809,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { Status IrEmitter::HandleSend(HloInstruction* send) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Send is not implemented on CPU. See b/33942983."); + return Unimplemented("Send is not implemented on CPU."); } Status IrEmitter::HandleSendDone(HloInstruction* send_done) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Send-done is not implemented on CPU. See b/33942983."); + return Unimplemented("Send-done is not implemented on CPU."); } Status IrEmitter::HandleSlice(HloInstruction* slice) { @@ -1981,12 +1979,12 @@ Status IrEmitter::HandleDynamicUpdateSlice( Status IrEmitter::HandleRecv(HloInstruction* recv) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Recv is not implemented on CPU. See b/33942983."); + return Unimplemented("Recv is not implemented on CPU."); } Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Recv-done is not implemented on CPU. See b/33942983."); + return Unimplemented("Recv-done is not implemented on CPU."); } Status IrEmitter::HandlePad(HloInstruction* pad) { @@ -1995,10 +1993,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { if (padding_dimension.edge_padding_low() < 0 || padding_dimension.edge_padding_high() < 0) { - return Unimplemented( - "Negative padding not supported in the CPU backend (b/34628603); " - "this should have been eliminated at the HLO level: %s", - pad->padding_config().ShortDebugString().c_str()); + return InternalErrorStrCat( + "Encountered negative padding in IrEmitter on CPU. " + "This should have been eliminated at the HLO level. ", + pad->ToString()); } } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 9780bac16ec17eed2c1df64f01bcb753e26b46f0..4468adbadbf823f1420a8b665a26f66cb7d36b43 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -428,7 +428,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::Intrinsic::round, {operand_value}, {operand_value->getType()}, ir_builder_); case HloOpcode::kSign: { - // TODO(b/32151903): Ensure consistent sign behavior for -0.0 + // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero); @@ -870,7 +870,10 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) const { if (prim_type != F32) { - return Unimplemented("inverse erf only implemented for F32 (b/34339814)"); + // TODO(b/34339814): Implement inverse erf for F64. + return Unimplemented( + "Inverse erf is only implemented for element " + "type F32."); } auto getFloat = [&](const float f) { return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f); @@ -1040,17 +1043,9 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, lhs_value, rhs_value, ir_builder_); case HloOpcode::kMinimum: - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: @@ -1067,6 +1062,26 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( } } +llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, + llvm::Value* rhs_value, + bool is_signed) const { + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); +} + +llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, + llvm::Value* rhs_value, + bool is_signed) const { + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); +} + llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, int64 operand_no) const { @@ -1363,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * max_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + PrimitiveType prim_type = hlo->shape().element_type(); + if (primitive_util::IsFloatingPointType(prim_type)) { + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + } else if (primitive_util::IsIntegralType(prim_type)) { + bool is_signed = primitive_util::IsSignedIntegralType(prim_type); + return EmitIntegralMin( + max_value, EmitIntegralMax(min_value, arg_value, is_signed), + is_signed); + } else { + return Unimplemented("Clamp unimplemented for %s", + PrimitiveType_Name(prim_type).c_str()); + } }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 1a48eb5fcb960b60d524ea56a43e15269576db76..c516a826d9e382bc738e54635426db639d17108c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -86,6 +86,12 @@ class ElementalIrEmitter { virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value) const; + llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed) const; + + llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed) const; + virtual StatusOr EmitErfInv(PrimitiveType prim_type, llvm::Value* value) const; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 80c2eed1097e4a7dcbf29b9b1c02fb9964983368..9da4fb97fa27a238fead74985cb481a9be1f4a65 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -129,8 +129,11 @@ cc_library( hdrs = [ "ir_emitter.h", "ir_emitter_context.h", + "ir_emitter_nested.h", + "ir_emitter_unnested.h", ], deps = [ + ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", ":gpu_executable", @@ -262,6 +265,7 @@ cc_library( ], deps = [ ":buffer_allocations", + ":cudnn_convolution_runner", ":infeed_manager", ":ir_emission_utils", ":partition_assignment", @@ -309,9 +313,41 @@ cc_library( ) cc_library( - name = "convolution_folding", - srcs = ["convolution_folding.cc"], - hdrs = ["convolution_folding.h"], + name = "cudnn_convolution_algorithm_picker", + srcs = ["cudnn_convolution_algorithm_picker.cc"], + hdrs = ["cudnn_convolution_algorithm_picker.h"], + deps = [ + ":cudnn_convolution_runner", + ":gpu_executable", + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_runner", + srcs = ["cudnn_convolution_runner.cc"], + hdrs = ["cudnn_convolution_runner.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_rewriter", + srcs = ["cudnn_convolution_rewriter.cc"], + hdrs = ["cudnn_convolution_rewriter.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -325,15 +361,18 @@ cc_library( ) tf_cc_test( - name = "convolution_folding_test", - srcs = ["convolution_folding_test.cc"], + name = "cudnn_convolution_rewriter_test", + srcs = ["cudnn_convolution_rewriter_test.cc"], deps = [ - ":convolution_folding", + ":cudnn_convolution_rewriter", + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], ) @@ -446,7 +485,8 @@ cc_library( srcs = ["gpu_compiler.cc"], hdrs = ["gpu_compiler.h"], deps = [ - ":convolution_folding", + ":cudnn_convolution_algorithm_picker", + ":cudnn_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 899cc5c83b99f1bb6154f883ca17871863e1f457..f76f15929d12eed63d8964acd61fb3fea3945006 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -36,366 +37,69 @@ using se::dnn::DataLayout; using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; -ConvolveScratchAllocator::ConvolveScratchAllocator( - int device_ordinal, DeviceMemoryAllocator* memory_allocator) - : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - -ConvolveScratchAllocator::~ConvolveScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -int64 ConvolveScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { - constexpr int64 kConvolveScratchSize = 1LL << 32; // 4GB by default. - return kConvolveScratchSize; -} - -se::port::StatusOr> -ConvolveScratchAllocator::AllocateBytes(se::Stream* stream, int64 byte_size) { - CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { - return se::port::Status( - se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", - byte_size, GetMemoryLimitInBytes(stream))); - } - - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); - total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); -} - -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind) { - switch (convolution_kind) { - case ConvolutionThunk::ConvolutionKind::kForward: - return "forward"; - case ConvolutionThunk::ConvolutionKind::kBackwardFilter: - return "backward_filter"; - case ConvolutionThunk::ConvolutionKind::kBackwardInput: - return "backward_input"; - } - return "unknown convolution kind"; -} - ConvolutionThunk::ConvolutionThunk( - ConvolutionKind convolution_kind, - const BufferAllocation::Slice& input_buffer, + CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, const Shape& input_shape, + const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo) + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + const HloInstruction* hlo) : Thunk(Kind::kConvolution, hlo), convolution_kind_(convolution_kind), input_buffer_(input_buffer), filter_buffer_(filter_buffer), output_buffer_(output_buffer), + tuple_result_buffer_(tuple_result_buffer), + scratch_buffer_(scratch_buffer), input_shape_(input_shape), filter_shape_(filter_shape), output_shape_(output_shape), window_(window), - dim_nums_(dim_nums) {} + dim_nums_(dim_nums), + algorithm_(algorithm) {} -tensorflow::Status ConvolutionThunk::ExecuteOnStream( +Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { - VLOG(3) << "Convolution kind: " << ConvolutionKindToString(convolution_kind_); - VLOG(3) << "input shape: { " << input_shape_.ShortDebugString() << " }"; - VLOG(3) << "filter shape: { " << filter_shape_.ShortDebugString() << " }"; - VLOG(3) << "Output shape: { " << output_shape_.ShortDebugString() << " }"; - VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }"; - VLOG(3) << "Window: { " << window_.ShortDebugString() << " }"; - - const int num_dimensions = window_.dimensions_size(); - CHECK_LE(num_dimensions, 3); - // cuDNN does not support 1D convolutions. We therefore express 1D - // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behavior of TF (see definition of conv1d in - // tensorflow/python/ops/nn_ops.py). - const int effective_num_dimensions = std::max(2, num_dimensions); - - CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size()); - for (const WindowDimension& dim : window_.dimensions()) { - CHECK_EQ(dim.padding_low(), dim.padding_high()); - } - - // cuDNN's convolution APIs support the BDYX layout for activations/output and - // the OIYX layout for weights. - BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - input_shape_.dimensions(dim_nums_.input_feature_dimension())) - .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - // Note that the dimensions are reversed. The same holds below. - input_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim))); - } - - FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) - .set_input_feature_map_count( - filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension())) - .set_output_feature_map_count(filter_shape_.dimensions( - dim_nums_.kernel_output_feature_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - filter_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(dim))); - } - - ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); - for (int dim = 0; dim < num_dimensions; ++dim) { - convolution_descriptor - .set_zero_padding( - static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).padding_low()) - .set_filter_stride( - static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).stride()); - } - - BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - output_shape_.dimensions(dim_nums_.output_feature_dimension())) - .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - output_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim))); - } - - // Add a singleton dimension in the 1D convolution case. - if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast(0), 1); - output_descriptor.set_spatial_dim(static_cast(0), 1); - filter_descriptor.set_spatial_dim(static_cast(0), 1); - convolution_descriptor - .set_zero_padding(static_cast(0), 0) - .set_filter_stride(static_cast(0), 1); - } - se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); se::DeviceMemory filter_data( buffer_allocations.GetDeviceAddress(filter_buffer_)); se::DeviceMemory output_data( buffer_allocations.GetDeviceAddress(output_buffer_)); - return ConvolveWithTune(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, buffer_allocations, stream); -} - -tensorflow::Status ConvolutionThunk::Convolve( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, se::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result) { - bool launch_ok; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - launch_ok = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_descriptor, input_data, output_descriptor, output_data, - convolution_descriptor, filter_descriptor, &filter_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - case ConvolutionKind::kBackwardInput: - launch_ok = stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_descriptor, filter_data, output_descriptor, - output_data, convolution_descriptor, input_descriptor, - &input_data, scratch_allocator, algorithm_config, - profile_result) - .ok(); - break; - case ConvolutionKind::kForward: - launch_ok = - stream - ->ThenConvolveWithAlgorithm( - input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, &output_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - } - if (launch_ok) { - return tensorflow::Status::OK(); - } - return InternalError( - "Unable to launch convolution for thunk %p with type %s and algorithm " - "(%lld, %lld)", - this, ConvolutionKindToString(convolution_kind_).c_str(), - algorithm_config.algorithm().algo_id(), - algorithm_config.algorithm_no_scratch().algo_id()); -} - -std::vector ConvolutionThunk::GetAlgorithms( - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const { - std::vector algorithms; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); - break; - } - return algorithms; -} - -static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { - if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); - } - return tensorflow::strings::StrCat(algo.algo_id()); -} - -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output descriptors. This works around b/68264959, an -// integer overflow in cuDNNv5 and cuDNNv6. -static bool ShouldIncludeWinogradNonfusedAlgo( - const BatchDescriptor& input_descriptor, - const BatchDescriptor& output_descriptor) { - int64 batch = input_descriptor.count(); - int64 in_depths = input_descriptor.feature_map_count(); - int64 in_rows = input_descriptor.height(); - int64 in_cols = input_descriptor.width(); - int64 out_depths = output_descriptor.feature_map_count(); - - int64 total_size = 16 * std::ceil(batch / 16.0) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - int64 threshold = 1L << 31; - - return total_size < threshold; -} - -tensorflow::Status ConvolutionThunk::ConvolveWithTune( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, - const BufferAllocations& buffer_allocations, se::Stream* stream) { - // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (!best_algorithm_.has_value()) { - best_algorithm_.emplace(); - - // Auto-tuning either is disabled or only happens in the first run of this - // function. - VLOG(2) << "Profiling for best convolution algorithm used for " - "ConvolutionThunk: " - << this; - - bool with_winograd_nonfused = - ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor); - - se::dnn::ProfileResult best_result; - se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = - GetAlgorithms(with_winograd_nonfused, stream->parent()); - for (auto algorithm : algorithms) { - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - se::dnn::ProfileResult profile_result; - VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk: " << this; - bool launch_ok = - Convolve(input_descriptor, input_data, filter_descriptor, filter_data, - output_descriptor, output_data, convolution_descriptor, - se::dnn::AlgorithmConfig(algorithm, algorithm), stream, - &scratch_allocator, &profile_result) - .ok(); - if (launch_ok && profile_result.is_valid()) { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " succeeded, taking " - << profile_result.elapsed_time_in_ms() - << "ms. (Best result: " << best_result.elapsed_time_in_ms() - << "ms)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalAllocatedBytes() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_without_scratch.elapsed_time_in_ms()) { - best_result_without_scratch = profile_result; - } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " failed."; - } - } - - if (best_result.is_valid()) { - best_algorithm_->set_algorithm(best_result.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm works with profiling. Fall back " - "to the default algorithm."; - best_algorithm_->set_algorithm(AlgorithmDesc()); + se::DeviceMemoryBase scratch = + buffer_allocations.GetDeviceAddress(scratch_buffer_); + + se::dnn::AlgorithmConfig algorithm_config( + se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false)); + + TF_RETURN_IF_ERROR(RunCudnnConvolution( + convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, + filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, + stream)); + + // Figure out which of output/input/filter is the result produced by this op, + // and write the result tuple. + void* result_ptr = [&] { + switch (convolution_kind_) { + case CudnnConvKind::kForward: + return output_data.opaque(); + case CudnnConvKind::kBackwardInput: + return input_data.opaque(); + case CudnnConvKind::kBackwardFilter: + return filter_data.opaque(); } + }(); + void* ptrs[] = {result_ptr, scratch.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); - if (best_result_without_scratch.is_valid()) { - best_algorithm_->set_algorithm_no_scratch( - best_result_without_scratch.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm without scratch works with " - "profiling. Fall back " - "to the default algorithm."; - best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc()); - } - } - - { - VLOG(2) << "Using convolution algorithm (" - << AlgorithmToString(best_algorithm_->algorithm()) << ", " - << AlgorithmToString(best_algorithm_->algorithm_no_scratch()) - << ") for ConvolutionThunk: " << this; - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - return Convolve(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, *best_algorithm_, stream, - &scratch_allocator, nullptr); + if (!stream->ok()) { + return InternalError("ConvolutionThunk::ExecuteOnStream failed."); } + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 46c94d0bf1e486fb91e63109efb8e4ba778c4120..ca9ef5277b3369dea3f698d1bcf0ad190d2c5217 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,106 +31,47 @@ limitations under the License. namespace xla { namespace gpu { -// A one-time scratch allocator for forward and backward convolution. The -// scratch buffers allocated are released on destruction. -// -// Not thread-safe. -class ConvolveScratchAllocator : public perftools::gputools::ScratchAllocator { - public: - ConvolveScratchAllocator(int device_ordinal, - DeviceMemoryAllocator* memory_allocator); - - ~ConvolveScratchAllocator() override; - - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; - - int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - - perftools::gputools::port::StatusOr> - AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; - - private: - const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; - int64 total_allocated_bytes_ = 0; -}; - // This class stores everything that StreamExecutor needs to launch a BNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. class ConvolutionThunk : public Thunk { public: - // ConvolutionThunk performs one of the following types of convolution. - enum class ConvolutionKind { - kBackwardFilter, // Backward convolution for filter. - kBackwardInput, // Backward convolution for input. - kForward, // Forward convolution. - }; - - // Constructs a thunk for launching a DNN convolution. + // Constructs a thunk for launching a DNN convolution. When run, it will + // write a tuple (result, scratch_memory) into `tuple_result_buffer`. + // + // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that + // we should use the default (i.e. baseline) cudnn algorithm. + // + // Note that "output" here doesn't refer to the output from running this + // thunk, but rather to the "output" of a hypothetical forward convolution + // that corresponds to this input+filter+output triple. That is, the result + // generated by this thunk is "output" for forward convs, "input" for + // backward-input convs, and "filter" for backward-filter convs. + // // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(ConvolutionKind convolution_kind, + ConvolutionThunk(CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, const HloInstruction* hlo); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; - // Does the convolution for the thunk on "stream". Auto-tuning happens on the - // first run of this function. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; - - // Returns true if the next run of ExecuteOnStream will do autotuning. If so, - // we want the GPU to be quiescent during autotuning, so as not to introduce - // noise in our results. - bool ShouldHaltAllActivityBeforeRunning( - perftools::gputools::Stream*) override { - return !best_algorithm_.has_value(); - } - - // Return true if scratch memory is needed to execute the thunk, that is - // either the best algorithm hasn't been chosen or the best algorithm is not - // the same as the no-scratch algorithm. This is because that the execution - // of the thunk is asynchronous, and the scratch allocator goes out of - // scope before the thunk finishes execution. Returning true tells the stream - // executor to make future thunks wait for this thunk to avoid reusing the - // deallocated scratch memory until this thunk is done with it. - bool ShouldBlockFutureThunks() { - if (!best_algorithm_.has_value()) { - return true; - } - - const perftools::gputools::dnn::AlgorithmDesc& best_alg = - best_algorithm_->algorithm(); - const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg = - best_algorithm_->algorithm_no_scratch(); - return (!best_alg.is_default() || !no_scratch_best_alg.is_default() || - !(best_alg == no_scratch_best_alg)); - } + // Does the convolution for the thunk on "stream". + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: - tensorflow::Status ConvolveWithTune( - const perftools::gputools::dnn::BatchDescriptor& input_descriptor, - perftools::gputools::DeviceMemory input_data, - const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, - perftools::gputools::DeviceMemory filter_data, - const perftools::gputools::dnn::BatchDescriptor& output_descriptor, - perftools::gputools::DeviceMemory output_data, - const perftools::gputools::dnn::ConvolutionDescriptor& - convolution_descriptor, - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream); + class ScratchAllocator; - tensorflow::Status Convolve( + Status Convolve( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, perftools::gputools::DeviceMemory input_data, const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, @@ -139,40 +81,26 @@ class ConvolutionThunk : public Thunk { const perftools::gputools::dnn::ConvolutionDescriptor& convolution_descriptor, const perftools::gputools::dnn::AlgorithmConfig& algorithm_config, - perftools::gputools::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, + perftools::gputools::Stream* stream, ScratchAllocator* scratch_allocator, perftools::gputools::dnn::ProfileResult* profile_result); - // Returns the convolve algorithms that can be used for this ConvolutionThunk. - std::vector GetAlgorithms( - bool with_winograd_nonfused, - perftools::gputools::StreamExecutor* stream_exec) const; - - // Fastest cuDNN convolution algorithm for this thunk learned from - // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value, indicating cuDNN's convolution will choose the best - // algorithm from some heuristics based on its parameters. - tensorflow::gtl::optional - best_algorithm_; - - const ConvolutionKind convolution_kind_; + const CudnnConvKind convolution_kind_; const BufferAllocation::Slice input_buffer_; const BufferAllocation::Slice filter_buffer_; const BufferAllocation::Slice output_buffer_; + const BufferAllocation::Slice tuple_result_buffer_; + const BufferAllocation::Slice scratch_buffer_; const Shape input_shape_; const Shape filter_shape_; const Shape output_shape_; const Window window_; - const ConvolutionDimensionNumbers dim_nums_; + int64 algorithm_; }; -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc new file mode 100644 index 0000000000000000000000000000000000000000..621b2d510fa98af40b89badebef5e45902f23d4c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -0,0 +1,370 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = perftools::gputools; + +using se::DeviceMemoryBase; +using se::dnn::AlgorithmConfig; +using se::dnn::AlgorithmDesc; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +class ScratchAllocator : public se::ScratchAllocator { + public: + ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + + ~ScratchAllocator() override; + + int64 GetMemoryLimitInBytes(se::Stream* stream) override { + return 1LL << 32; // 4GB. TODO(jlebar): Tune this? + } + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +ScratchAllocator::~ScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. + LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +se::port::StatusOr> ScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Failed to allocate %lld bytes on device %d.", + byte_size, device_ordinal_)); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory(allocated_buffer); +} + +// Determines whether we can safely perform a winograd non-fused convolution for +// the given input and output shapes. This works around b/68264959, an integer +// overflow in cuDNNv5 and cuDNNv6. +// +// TODO(jlebar): We shouldn't need this check for cuDNNv7. +bool ShouldIncludeWinogradNonfusedAlgo( + const Shape& input_shape, const Shape& output_shape, + const ConvolutionDimensionNumbers& dnums) { + int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); + int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); + int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); + int64 in_cols = + dnums.input_spatial_dimensions_size() == 1 + ? 1 + : input_shape.dimensions(dnums.input_spatial_dimensions(1)); + int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); + + int64 total_size = CeilOfRatio(batch, int64{16}) * + std::max(in_depths, out_depths) * in_cols * in_rows * + sizeof(float); + + const int64 threshold = 1L << 31; + return total_size < threshold; +} + +std::vector GetAlgorithms(CudnnConvKind kind, + bool with_winograd_nonfused, + se::StreamExecutor* stream_exec_) { + std::vector algorithms; + switch (kind) { + case CudnnConvKind::kBackwardFilter: + CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kBackwardInput: + CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kForward: + CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused, + &algorithms)); + break; + } + + // Remove any algorithms with tensor math enabled. These have lower precision + // than regular algorithms, and we don't yet have a way to turn this on/off in + // XLA. + algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(), + [&](const AlgorithmDesc& a) { + return a.tensor_ops_enabled(); + }), + algorithms.end()); + + return algorithms; +} + +string AlgorithmToString(const AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + } + return tensorflow::strings::StrCat(algo.algo_id()); +} + +string NumBytesToString(int64 bytes) { + return tensorflow::strings::StrCat( + tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); +} + +} // anonymous namespace + +// We could have caching here so that we don't redo this work for two identical +// convolutions. Unfortunately our cache key would have to be a tuple +// containing the protos passed to this function, and we have no utility for +// hashing protos. We could write our own hash functions, but they'd silently +// break if we ever added a field to one of the protos. Perhaps we could hack +// using the binary-encoded proto as the hash key, on the assumption that two +// protos being binary-equal is a sufficient, if not necessary, condition for +// proper equality. But that would still leave us open to having unnecessary +// cache misses and doing extra work. Overall, caching doesn't seem worth the +// trouble, but we may want to revisit this if we ever find a model where +// caching would speed up compilation a lot. +optional> +CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + // Create a stream for us to do our work on. + se::Stream stream{stream_exec_}; + stream.Init(); + const auto device_ordinal = stream_exec_->device_ordinal(); + + // allocator either points to this->allocator_ or, if that's null, to a + // StreamExecutorMemoryAllocator for stream_exec_. + DeviceMemoryAllocator* allocator; + optional se_allocator; + if (allocator_ != nullptr) { + allocator = allocator_; + } else { + se_allocator.emplace( + stream_exec_->platform(), + tensorflow::gtl::ArraySlice({stream_exec_})); + allocator = &*se_allocator; + } + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + // + // We don't put any data in these buffers, because (in theory, anyway) the + // speed of a conv isn't affected by the data being convolved. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + se::port::StatusOr input_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(input_shape)); + se::port::StatusOr filter_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(filter_shape)); + se::port::StatusOr output_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(output_shape)); + if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + LOG(WARNING) + << "Couldn't allocate space for input/filter/output of convolution " + << instr->ToString() << ". Falling back to default algorithm."; + return nullopt; + } + + const bool use_winograd_nonfused = + ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums); + se::dnn::ProfileResult best_result; + int64 best_result_bytes_used = 0; + for (const AlgorithmDesc& alg : + GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + ScratchAllocator scratch_allocator(device_ordinal, allocator); + se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " + << instr->ToString(); + + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf.ValueOrDie()), + se::DeviceMemory(filter_buf.ValueOrDie()), + se::DeviceMemory(output_buf.ValueOrDie()), + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); + + if (launch_ok && profile_result.is_valid()) { + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) + << " succeeded, taking " << profile_result.elapsed_time_in_ms() + << "ms and using " << NumBytesToString(scratch_bytes_used) + << " of scratch (Best result: " + << best_result.elapsed_time_in_ms() << "ms, " + << NumBytesToString(best_result_bytes_used) << " of scratch)"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + best_result_bytes_used = scratch_bytes_used; + } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; + } + } + if (best_result.is_valid()) { + VLOG(2) << "Best algorithm for " << instr->ToString() << ": " + << AlgorithmToString(best_result.algorithm()) << ", takes " + << best_result.elapsed_time_in_ms() << "ms, and uses " + << best_result_bytes_used << "B of scratch memory."; + return std::make_pair(best_result.algorithm().algo_id(), + best_result_bytes_used); + } + + LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() + << " failed. Falling back to default algorithm."; + return nullopt; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( + HloInstruction* instr) { + CHECK(IsCustomCallToDnnConvolution(*instr)); + + const auto& call_target = instr->custom_call_target(); + const auto& lhs_shape = instr->operand(0)->shape(); + const auto& rhs_shape = instr->operand(1)->shape(); + const auto& conv_result_shape = instr->shape().tuple_shapes(0); + optional> alg_and_scratch_bytes; + if (call_target == kCudnnConvForwardCallTarget) { + alg_and_scratch_bytes = PickBestAlgorithm( + CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardInputCallTarget) { + alg_and_scratch_bytes = PickBestAlgorithm( + CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), + instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardFilterCallTarget) { + alg_and_scratch_bytes = PickBestAlgorithm( + CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instr->ToString(); + } + + if (!alg_and_scratch_bytes.has_value()) { + return false; + } + + int64 algorithm; + int64 scratch_bytes; + std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes; + + VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " + << NumBytesToString(scratch_bytes) + << " of scratch memory: " << instr->ToString(); + + // Replace instr with a new CustomCall which has the correct algorithm, and + // whose output shape has the appropriate amount of scratch memory. + HloComputation* computation = instr->parent(); + Shape new_call_shape = + ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {scratch_bytes})}); + HloInstruction* algorithm_hlo = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); + HloInstruction* new_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + new_call_shape, + {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo}, + instr->custom_call_target())); + new_call->set_window(instr->window()); + new_call->set_convolution_dimension_numbers( + instr->convolution_dimension_numbers()); + + // Repackage new_call so it has the same shape as the original call, namely + // (conv_result, u8[0]). + HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple( + {computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_call_shape.tuple_shapes(0), new_call, 0)), + computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({})))})); + + TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); + return true; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( + HloComputation* computation) { + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + + bool changed = false; + for (auto* instr : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr)); + changed |= result; + } + return changed; +} + +StatusOr CudnnConvolutionAlgorithmPicker::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h new file mode 100644 index 0000000000000000000000000000000000000000..10e49daee5df187e5ad90b7adf8c92aa9a63ba21 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for +// each and adding explicit scratch space to the CustomCalls. +class CudnnConvolutionAlgorithmPicker : public HloPassInterface { + public: + // If the `allocator` parameter is not null, we will use it to allocate temp + // memory while timing the various convolution algorithms. If it's null, + // we'll use the default allocator on the StreamExecutor. + CudnnConvolutionAlgorithmPicker( + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator) + : stream_exec_(stream_exec), allocator_(allocator) {} + + tensorflow::StringPiece name() const override { + return "cudnn-convolution-algorithm-picker"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation* computation); + StatusOr RunOnInstruction(HloInstruction* instr); + tensorflow::gtl::optional> PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + + perftools::gputools::StreamExecutor* stream_exec_; // never null + DeviceMemoryAllocator* allocator_; // may be null +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc similarity index 83% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index b0626ca3bc9f843e513d4727932f0e2d5fa37748..e0c73aa73acb7f3313eb54fb07390cb76590433e 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include #include @@ -33,14 +33,32 @@ namespace xla { namespace gpu { namespace { + +bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { + const ConvolutionDimensionNumbers& dnums = + conv->convolution_dimension_numbers(); + if (dnums.input_spatial_dimensions_size() > 3) { + return false; + } + + // CuDNN does not accept zero-element arguments + if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || + ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + return false; + } + + if (window_util::HasWindowReversal(conv->window())) { + return false; + } + return true; +} + // Try to match a backward filter pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardFilter(HloInstruction* conv) { +std::tuple MatchBackwardFilter( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -190,18 +208,15 @@ MatchBackwardFilter(HloInstruction* conv) { backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({conv}), - backward_conv_window, backward_conv_dnums); + return std::make_tuple(true, backward_conv_window, backward_conv_dnums); } // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardInput(HloInstruction* conv) { +std::tuple MatchBackwardInput( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); @@ -374,58 +389,82 @@ MatchBackwardInput(HloInstruction* conv) { dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, - std::vector({conv, reverse_filter}), - new_window, dnums); + return std::make_tuple(true, new_window, dnums); } -} // namespace -StatusOr ConvolutionFolding::Run(HloModule* module) { - HloComputation* entry_computation = module->entry_computation(); - std::vector convs; - for (auto* hlo : entry_computation->instructions()) { - if (hlo->opcode() == HloOpcode::kConvolution) { - convs.push_back(hlo); - } - } +// Tries to rewrite a single convolution into a call to cudnn. +StatusOr RunOnInstruction(HloInstruction* conv) { + CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - bool changed = false; - for (HloInstruction* conv : convs) { + HloInstruction* custom_call = [&]() -> HloInstruction* { bool match; - std::vector hlos_to_fuse; Window window; ConvolutionDimensionNumbers dnums; - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv); + + std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter, - window, dnums); - VLOG(2) << "to backward filter convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + return CreateCudnnConvBackwardFilter( + conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums); } - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums) = MatchBackwardInput(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput, - window, dnums); - VLOG(2) << "to backward input convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + // Backward input conv subsumes the conv plus the reverse in operand 1. + HloInstruction* reverse = conv->mutable_operand(1); + CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); + HloInstruction* rhs = reverse->mutable_operand(0); + + return CreateCudnnConvBackwardInput( + conv->shape(), conv->mutable_operand(0), rhs, window, dnums); } + + // If all else fails, try a forward convolution. + if (CanImplementAsCudnnForwardConv(conv)) { + return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers()); + } + + return nullptr; + }(); + + if (custom_call == nullptr) { + return false; + } + + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out + // the conv result and replace `conv` with it. + TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( + conv, + HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); + return true; +} + +// Rewrites the convolutions in the given computation into calls to cudnn. +// Returns true if it made any changes. +StatusOr RunOnComputation(HloComputation* computation) { + std::vector convs; + for (auto* hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kConvolution) { + convs.push_back(hlo); + } + } + + bool changed = false; + for (HloInstruction* conv : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv)); + changed |= result; + } + return changed; +} +} // namespace + +StatusOr CudnnConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; } return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h similarity index 63% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.h rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index f9c898721f8dd6b8b7e74c82bb2085cc437eaad5..0c0578d88840fed1d77f7456c9acef27dec380f5 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -22,10 +22,12 @@ limitations under the License. namespace xla { namespace gpu { -class ConvolutionFolding : public HloPassInterface { +// Rewrites plain convolutions, backwards-filter convolutions, and +// backwards-input convolutions into CustomCall HLOs that call into cuDNN. +class CudnnConvolutionRewriter : public HloPassInterface { public: tensorflow::StringPiece name() const override { - return "convolution-folding"; + return "cudnn-convolution-rewriter"; } StatusOr Run(HloModule* module) override; @@ -34,4 +36,4 @@ class ConvolutionFolding : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc similarity index 82% rename from tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 34e6bdb117d47a3d7e1eb3bae5806e130e94ea79..65588b6aaf24da628ea586eb52c462b78b8daaa7 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace gpu { +namespace { -class ConvolutionFoldingTest : public HloTestBase { +namespace op = xla::testing::opcode_matchers; + +class CudnnConvolutionRewriterTest : public HloTestBase { public: - ConvolutionFoldingTest() { + CudnnConvolutionRewriterTest() { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -44,7 +50,8 @@ class ConvolutionFoldingTest : public HloTestBase { // the batch and feature dimension in the activations, and treat the batch // dimension in gradients as the input feature dimension in the filter. // - // TODO(jingyue): Add more tests on NCHW input order which TF also supports. + // TODO(jingyue): Add more tests on NCHW input order, which TF also + // supports. tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); @@ -74,9 +81,8 @@ class ConvolutionFoldingTest : public HloTestBase { } protected: - bool FoldConvolution(HloModule* module) { - ConvolutionFolding convolution_folding; - return convolution_folding.Run(module).ValueOrDie(); + bool RunPass(HloModule* module) { + return CudnnConvolutionRewriter().Run(module).ValueOrDie(); } // A convolution window with stride 1 and zero padding. The size fields are @@ -86,7 +92,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -108,14 +114,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -135,12 +140,17 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from block35 training. -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -162,15 +172,15 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from inception v3 training. -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -192,14 +202,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -221,14 +230,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -272,14 +280,15 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); // Low padding of the backward input convolution // = kernel_size - 1 - low padding on gradients. EXPECT_EQ(3, window_dim.padding_low()); @@ -291,7 +300,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. HloInstruction* output = @@ -316,17 +325,16 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); } // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. @@ -347,8 +355,12 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_input_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from Inception V3 training. @@ -365,7 +377,8 @@ TEST_F(ConvolutionFoldingTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -397,14 +410,14 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); EXPECT_EQ(0, window_dim.padding_low()); EXPECT_EQ(0, window_dim.padding_high()); EXPECT_EQ(2, window_dim.stride()); @@ -413,7 +426,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -442,8 +455,12 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from //learning/brain/google/xla/benchmarks/resnet.py @@ -460,7 +477,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { // // We should fuse BC even though padding on activations is uneven, because // PadInsertion will canonicalize the fusion HLO. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -493,13 +510,12 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - const HloInstruction* backward_conv = entry_computation->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - backward_conv->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); const WindowDimension& backward_conv_col_dim = - backward_conv->window().dimensions(1); + entry_computation->root_instruction()->operand(0)->window().dimensions(1); EXPECT_EQ(0, backward_conv_col_dim.padding_low()); EXPECT_EQ(1, backward_conv_col_dim.padding_high()); } @@ -515,7 +531,7 @@ TEST_F(ConvolutionFoldingTest, // // We currently don't fuse BC because PadInsertion doesn't support negative // padding on the gradients of backward convolution (b/32744257). -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -544,9 +560,14 @@ TEST_F(ConvolutionFoldingTest, .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +} // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5f52cf62bf6edb7925ec3b22fc1772ffbfbf089 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -0,0 +1,221 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = ::perftools::gputools; + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::Stream; +using se::dnn::AlgorithmConfig; +using se::dnn::BatchDescriptor; +using se::dnn::ConvolutionDescriptor; +using se::dnn::DataLayout; +using se::dnn::DimIndex; +using se::dnn::FilterDescriptor; +using se::dnn::FilterLayout; +using se::dnn::ProfileResult; + +// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, +// returning it (in its entirety) the first time Allocate() is called. +class ScratchBufAllocator : public se::ScratchAllocator { + public: + explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) + : scratch_(scratch) {} + + ~ScratchBufAllocator() override = default; + + int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { + return scratch_.size(); + } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override { + if (allocated_) { + return se::port::InternalError( + "Can't allocate twice from a ScratchBufAllocator."); + } + if (byte_size > scratch_.size()) { + return se::port::InternalError(tensorflow::strings::StrCat( + "Can't allocate ", byte_size, + " bytes from a ScratchBufAllocator of size ", scratch_.size())); + } + + allocated_ = true; + return se::DeviceMemory(scratch_); + } + + private: + se::DeviceMemoryBase scratch_; + bool allocated_ = false; +}; + +} // anonymous namespace + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + } +} + +Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape, + const Shape& filter_shape, const Shape& output_shape, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + AlgorithmConfig algorithm, Stream* stream, + ProfileResult* profile_result /*= nullptr*/) { + ScratchBufAllocator scratch_allocator(scratch_buf); + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, algorithm, + stream, profile_result); +} + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, DeviceMemory input_buf, + DeviceMemory filter_buf, DeviceMemory output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, + Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); + VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; + VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; + VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; + VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; + + const int num_dimensions = window.dimensions_size(); + CHECK_LE(num_dimensions, 3); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. + // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + + CHECK_EQ(F32, output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); + for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dim.padding_low(), dim.padding_high()); + } + + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + BatchDescriptor input_descriptor(effective_num_dimensions); + input_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + input_shape.dimensions(dnums.input_feature_dimension())) + .set_count(input_shape.dimensions(dnums.input_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. + input_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + input_shape.dimensions(dnums.input_spatial_dimensions(dim))); + } + + FilterDescriptor filter_descriptor(effective_num_dimensions); + filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + .set_input_feature_map_count( + filter_shape.dimensions(dnums.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape.dimensions(dnums.kernel_output_feature_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + filter_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); + } + + ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + for (int dim = 0; dim < num_dimensions; ++dim) { + convolution_descriptor + .set_zero_padding( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).padding_low()) + .set_filter_stride( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).stride()); + } + + BatchDescriptor output_descriptor(effective_num_dimensions); + output_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_count(output_shape.dimensions(dnums.output_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + output_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + output_shape.dimensions(dnums.output_spatial_dimensions(dim))); + } + + // Add a singleton dimension in the 1D convolution case. + if (num_dimensions == 1) { + input_descriptor.set_spatial_dim(static_cast(0), 1); + output_descriptor.set_spatial_dim(static_cast(0), 1); + filter_descriptor.set_spatial_dim(static_cast(0), 1); + convolution_descriptor.set_zero_padding(static_cast(0), 0) + .set_filter_stride(static_cast(0), 1); + } + + switch (kind) { + case CudnnConvKind::kForward: + stream->ThenConvolveWithAlgorithm( + input_descriptor, input_buf, filter_descriptor, filter_buf, + convolution_descriptor, output_descriptor, &output_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardInput: + stream->ThenConvolveBackwardDataWithAlgorithm( + filter_descriptor, filter_buf, output_descriptor, output_buf, + convolution_descriptor, input_descriptor, &input_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardFilter: + stream->ThenConvolveBackwardFilterWithAlgorithm( + input_descriptor, input_buf, output_descriptor, output_buf, + convolution_descriptor, filter_descriptor, &filter_buf, + scratch_allocator, algorithm, profile_result); + break; + } + + if (!stream->ok()) { + return InternalError( + "Unable to launch convolution with type %s and algorithm (%lld, %lld)", + CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + algorithm.algorithm_no_scratch().algo_id()); + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..b101f76510c129fd22b246e5f0348848192ecbba --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ + +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// This file contains low-level routines for running cudnn convolutions. + +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. +enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter +}; + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + +// Calls into cudnn to run the specified convolution. +// +// Note that depending on the value of CudnnConvKind, the result of this call +// may be written into input_buf, filter_buf, or output_buf! +// +// At the moment we only support cudnn convolutions over floats. +// +// We provide one overload which takes a scratch buffer, and another which takes +// an allocator which is responsible for allocating the scratch space. In +// theory the second one shouldn't be necessary -- users of this function could +// just ask cudnn how much scratch space it needs for a particular convolution. +// But in practice, StreamExecutor does not expose such an API, and in the name +// of parsimony, perhaps it's better not to add it. Instead, the first time you +// call a convolution, you should call the version that takes a scratch +// allocator and take note of how much memory is used. The next time you call +// the same conv, you can provide an explicitly preallocated scratch buffer of +// that size, if you like. +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, + perftools::gputools::DeviceMemory input_buf, + perftools::gputools::DeviceMemory filter_buf, + perftools::gputools::DeviceMemory output_buf, + perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, + perftools::gputools::DeviceMemory input_buf, + perftools::gputools::DeviceMemory filter_buf, + perftools::gputools::DeviceMemory output_buf, + perftools::gputools::ScratchAllocator* scratch_allocator, + const Window& window, const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 495ae1710fa3511a6fcdaeda94362cbaebcf174b..28ebd034ee0c89137f4e6eb417d8a37f4a00af7a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -35,8 +35,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -46,8 +47,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -127,7 +128,9 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { +tensorflow::Status OptimizeHloModule(HloModule* hlo_module, + se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -143,6 +146,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { // most ops. pipeline.AddPass(BF16, F32); pipeline.AddPass(); + { auto& pass = pipeline.AddPass>("simplification"); @@ -173,7 +177,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { pass.AddPass(); pass.AddPass(); } - pipeline.AddPass(); + pipeline.AddPass( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { @@ -185,6 +189,58 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } + + { + // Convert convolutions into CustomCalls to cudnn, then canonicalize them + // (PadInsertion). + HloPassPipeline pipeline("conv_canonicalization"); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); + pipeline.AddPass(); + + // Choose the fastest algorithm for each conv. + // + // In theory doing this here is way too early: It needs to happen after + // layout assignment, because the layout of the inputs/outputs affects the + // speed of the conv. But currently we only allow only one input/output + // layout when calling cudnn, so there's no ambiguity. + // + // We pick the algorithm at this early stage so we can generate better HLO. + // After CudnnConvolutionRewriter, our convolutions are CustomCalls which + // return a tuple (conv_result, scratch_memory), and the each conv uses 0 + // bytes of scratch: + // + // customcall = (f32[...], f32[0]) + // return gte(customcall, 0) + // + // The algorithm picker then chooses the best algorithm, and potentially + // increases the scratch space. It replaces customcall with new_tuple, + // giving us the following: + // + // new_customcall = (f32[...], f32[N]) + // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) + // return gte(new_tuple, 0) + // + // The new tuple and gte instructions then be simplified away, because + // nobody is expected to use the scratch value. + // + // However, if we were to run CudnnConvolutionAlgorithmPicker after layout + // assignment, fusion would already have run, and the gte(customcall, 0) + // would probably already be into a fusion node. We can't simplify across + // HloComputation boundaries, so in this case we wouldn't be able to + // simplify away the new_tuple bits. + // + // We'll need to revisit this if we ever allow multiple layouts for the + // inputs/outputs of a cudnn convolution. + pipeline.AddPass(stream_exec, + device_allocator); + // Clean up new_tuple described above. + pipeline.AddPass(); + pipeline.AddPass(); + + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + { HloPassFix fusion("fusion"); fusion.AddInvariantChecker(); @@ -212,9 +268,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting( - HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* /*device_allocator*/) { +tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output @@ -222,9 +276,10 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); - pipeline.AddPass(); + pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -417,7 +472,8 @@ StatusOr> GpuCompiler::RunHloPasses( XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); Tracing::TraceMe annotation("HLO Transforms", module->name(), /*is_expensive=*/true); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get())); + TF_RETURN_IF_ERROR( + OptimizeHloModule(module.get(), stream_exec, device_allocator)); return std::move(module); } @@ -428,8 +484,7 @@ StatusOr> GpuCompiler::RunBackend( TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get(), stream_exec, - device_allocator)); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); llvm::LLVMContext llvm_context; std::string buffer; @@ -464,16 +519,17 @@ StatusOr> GpuCompiler::RunBackend( /*color_alignment=*/[](LogicalBuffer::Color) { return kCudaMallocAlignBytes; })); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. + // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() + // include headers, so no need for us to print them ourselves. + XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); XLA_VLOG_LINES(2, module->ToString()); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); - if (!xla_dump_hlo_proto_to.empty()) { + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index e3b493c6630d061c00dc6c67bdaecdb2e5d68533..88bf5a74fa03618d6f61365450f05e6f5d1a0c86 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -78,6 +78,12 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } + } else if (IsCustomCallToDnnConvolution(*hlo)) { + // The last argument to a CUDNN convolution is its algorithm, which must + // be an HLO constant -- it shouldn't be copied. + for (int64 i = 0; i < hlo->operand_count() - 1; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } } else if (ImplementedAsLibraryCall(*hlo)) { // For all other library calls, materialize all the operands into memory. for (int64 i = 0; i < hlo->operand_count(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 58915f1f62f0c0f320443058a798333c498ffe47..89f1e625884568bf7370b3801d851ef4846c2a98 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -28,122 +28,114 @@ limitations under the License. namespace xla { namespace gpu { +// cuDNN convolutions are called with specific layouts on the input, output, +// and filter: +// +// input: DataLayout::kBatchDepthYX +// output: DataLayout::kBatchDepthYX +// filter: FilterLayout::kOutputInputYX +// +// The order dimensions in the constant name is major-to-minor (eg, the +// most-major dimension of the input is batch, most-minor is X). The +// specific dimension numbers these named dimensions correspond to is +// determined by the ConvolutionDimensionNumbers argument. Y is spatial +// dimension 0, and X is spatial dimension 1. +// +// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. +static Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints) { + CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); + Shape input_shape; + Shape filter_shape; + Shape output_shape; + const auto& target = instr->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->shape().tuple_shapes(0); + } else if (target == kCudnnConvBackwardInputCallTarget) { + input_shape = instr->shape().tuple_shapes(0); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->operand(0)->shape(); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->shape().tuple_shapes(0); + output_shape = instr->operand(1)->shape(); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + + // Construct minor-to-major dimension orders for operands and result. + // cuDNN's convolution APIs support the BDYX layout for activations/output + // and the OIYX layout for weights. + // TODO(b/29399649): Be more flexible about handling layouts of cuDNN + // calls after we switch to cuDNN v5. + const ConvolutionDimensionNumbers& dimension_numbers = + instr->convolution_dimension_numbers(); + std::vector input_layout; + for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; + --i) { + input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); + } + input_layout.push_back(dimension_numbers.input_feature_dimension()); + input_layout.push_back(dimension_numbers.input_batch_dimension()); + *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); + + std::vector filter_layout; + for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; + --i) { + filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); + } + filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); + filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); + *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); + + std::vector output_layout; + for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; + --i) { + output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + } + output_layout.push_back(dimension_numbers.output_feature_dimension()); + output_layout.push_back(dimension_numbers.output_batch_dimension()); + *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); + + // The custom call returns a tuple of (actual_result, scratch_buffer); + // call_result_buf is the logical buffer for actual_result, the thing that + // contains the result of the conv call. + TF_ASSIGN_OR_RETURN(const LogicalBuffer* call_result_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instr, /*index=*/{0})); + + // Set layouts of the instructions' shapes. + if (target == kCudnnConvForwardCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardInputCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + return Status::OK(); +} + Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { for (auto* instruction : constraints->computation()->instructions()) { - // cuDNN is called with specific layouts on the input, output, and filter: - // - // input: DataLayout::kBatchDepthYX - // output: DataLayout::kBatchDepthYX - // filter: FilterLayout::kOutputInputYX - // - // The order dimensions in the constant name is major-to-minor (eg, the - // most-major dimension of the input is batch, most-minor is X). The - // specific dimension numbers these named dimensions correspond to is - // determined by the ConvolutionDimensionNumbers argument. Y is spatial - // dimension 0, and X is spatial dimension 1. - // - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. - if (ImplementedAsDnnConvolution(*instruction)) { - HloInstruction* input = nullptr; - HloInstruction* filter = nullptr; - HloInstruction* output = nullptr; - if (instruction->opcode() == HloOpcode::kConvolution) { - input = instruction->mutable_operand(0); - filter = instruction->mutable_operand(1); - output = instruction; - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - input = instruction->mutable_operand(0); - filter = instruction; - output = instruction->mutable_operand(1); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - input = instruction; - filter = instruction->mutable_operand(1); - output = instruction->mutable_operand(0); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } - - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instruction->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; - i >= 0; --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - Shape input_shape(input->shape()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; - i >= 0; --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back( - dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back( - dimension_numbers.kernel_output_feature_dimension()); - Shape filter_shape(filter->shape()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; - i >= 0; --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); - } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - Shape output_shape(output->shape()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); - - // Set layouts of the instructions' shapes. - if (instruction->opcode() == HloOpcode::kConvolution) { - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, output, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, output, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, output)); - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, filter, 0)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(filter_shape, filter)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, filter, 1)); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(input_shape, input)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, input, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, input, 1)); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } + if (IsCustomCallToDnnConvolution(*instruction)) { + TF_RETURN_IF_ERROR( + AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); } } return Status::OK(); @@ -151,9 +143,12 @@ Status GpuLayoutAssignment::AddBackendConstraints( bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) { - // Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. - return !IsCustomCallToDnnBatchNorm(*instruction); + // - Inputs to cudnn batchnorm custom calls don't need the major-first layout + // (i.e. {n, n-1, ...0}) -- we can handle any layout. + // - Inputs to cudnn convolution require custom layouts handled in + // AddBackendConstraints. + return !IsCustomCallToDnnBatchNorm(*instruction) && + !IsCustomCallToDnnConvolution(*instruction); } Status GpuLayoutAssignment::PropagateOperandConstraint( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index c2115c49993ef71c4b6dd584e7e0498807666613..061210352cf12e6802d066d311fd2cb481673f15 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -22,12 +22,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + void HloToIrBindings::EmitBasePointersForHlos( tensorflow::gtl::ArraySlice io_hlos, tensorflow::gtl::ArraySlice non_io_hlos) { @@ -191,7 +196,11 @@ static bool BuffersInvariantWithinConsumer( llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, const HloInstruction& consumer, const ShapeIndex& shape_index) { - llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), + llvm::Value* base_ptr = GetBasePointer(hlo, shape_index); + CHECK_NE(base_ptr, nullptr) + << "Buffer not assigned for shape_index " << shape_index.ToString() + << " of " << hlo.ToString(); + llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); @@ -223,5 +232,54 @@ void HloToIrBindings::UnbindAllLocalIrValues() { } } +string HloToIrBindings::ToString() const { + string s = StrCat("** HloToIrBindings **\n"); + StrAppend(&s, " is_nested_=", is_nested_, "\n"); + StrAppend(&s, + " temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_), + "\n"); + + if (base_ptrs_.empty()) { + return s; + } + + // Iterate over all computations in the module in topological order, and print + // out the base pointers we have in each computation in topological order. + for (const HloComputation* computation : + base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) { + bool is_first = true; + for (const HloInstruction* instr : + computation->MakeInstructionPostOrder()) { + auto it = base_ptrs_.find(instr); + if (it == base_ptrs_.end()) { + continue; + } + if (is_first) { + StrAppend(&s, " Base pointers for computation ", computation->name(), + ":\n"); + is_first = false; + } + StrAppend(&s, " ", instr->ToString()); + + const ShapeTree& shape_tree = it->second; + if (!ShapeUtil::IsTuple(instr->shape())) { + const llvm::Value* val = shape_tree.begin()->second; + StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n"); + continue; + } + + StrAppend(&s, "\n"); + for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end(); + ++shape_it) { + llvm::Value* val = shape_it->second; + StrAppend(&s, " ", shape_it->first.ToString(), " -> ", + (val != nullptr ? llvm_ir::DumpToString(*val) : "null"), + "\n"); + } + } + } + return s; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 62ae1769a1f2fb3b9acaf35bdf18a793232500b0..1fe7970e7d94ad4a4cad6aabcfc84a1356753443 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -87,6 +87,8 @@ class HloToIrBindings { const HloInstruction& consumer, const ShapeIndex& shape_index = {}); + string ToString() const; + private: // Emits IR to resolve (possibly) recursive GetTupleElement instructions. llvm::Value* EmitGetTupleElement(const HloInstruction* gte, diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 1d47ffde4331868cbc8a8afb2d01b11e77a7fab0..2d6dad27a59978da6e4719afc50ebee5e641dde0 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -137,49 +137,6 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { - HloComputation::Builder builder(TestName()); - auto input = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), "input")); - auto filter = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 2}), "filter")); - - Window conv_window; - WindowDimension* conv_window_row = conv_window.add_dimensions(); - conv_window_row->set_size(1); - WindowDimension* conv_window_col = conv_window.add_dimensions(); - conv_window_col->set_size(2); - conv_window_col->set_padding_high(1); - - ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_input_batch_dimension(0); - conv_dnums.set_output_batch_dimension(0); - conv_dnums.set_input_feature_dimension(1); - conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_input_spatial_dimensions(2); - conv_dnums.add_output_spatial_dimensions(2); - conv_dnums.add_input_spatial_dimensions(3); - conv_dnums.add_output_spatial_dimensions(3); - conv_dnums.set_kernel_output_feature_dimension(0); - conv_dnums.set_kernel_input_feature_dimension(1); - conv_dnums.add_kernel_spatial_dimensions(2); - conv_dnums.add_kernel_spatial_dimensions(3); - - auto conv = builder.AddInstruction( - HloInstruction::CreateConvolve(ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), - input, filter, conv_window, conv_dnums)); - auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 1, 1, 1}), conv, {3, 2, 1, 0})); - builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); -} - TEST_F(InstructionFusionTest, GetTupleElementFused) { HloComputation::Builder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {8}); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 76566a9e3dbbc936ff90fe3f440ede14bf4e5233..2f65edffea81db7dba1f8545f92b27ea622044e7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -90,43 +90,6 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { return false; } -bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { - // We can only do this if the HLO is unnested. - if (hlo.parent() != hlo.GetModule()->entry_computation()) { - return false; - } - - // Forward convolution. - if (hlo.opcode() == HloOpcode::kConvolution) { - const ConvolutionDimensionNumbers& dnums = - hlo.convolution_dimension_numbers(); - if (dnums.input_spatial_dimensions_size() > 3) { - return false; - } - - // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) || - ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) { - return false; - } - - if (window_util::HasWindowReversal(hlo.window())) { - return false; - } - - return true; - } - - // Backward convolution. - if (hlo.opcode() == HloOpcode::kFusion && - (hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardFilter || - hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardInput)) { - return true; - } - - return false; -} - const char* const kCudnnBatchNormForwardInferenceCallTarget = "__cudnn$batchNormalizationForwardInference"; const char* const kCudnnBatchNormForwardTrainingCallTarget = @@ -144,9 +107,76 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) { target == kCudnnBatchNormBackwardCallTarget; } +const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward"; +const char* const kCudnnConvBackwardInputCallTarget = + "__cudnn$convBackwardInput"; +const char* const kCudnnConvBackwardFilterCallTarget = + "__cudnn$convBackwardFilter"; + +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const auto& target = hlo.custom_call_target(); + return target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget || + target == kCudnnConvBackwardFilterCallTarget; +} + bool ImplementedAsLibraryCall(const HloInstruction& hlo) { - return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo) || - IsCustomCallToDnnBatchNorm(hlo); + return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) || + IsCustomCallToDnnConvolution(hlo); +} + +static HloInstruction* CreateCudnnConv( + const char* call_target, const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dnums) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. + Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + // Our CustomCall takes three arguments: The conv lhs and rhs, and the cudnn + // algorithm to use. It's up to a later pass to choose the algorithm, so to + // indicate that we haven't yet made a choice, we speicfy -1 for that arg. + HloInstruction* negative_one = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction* custom_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + call_shape, {lhs, rhs, negative_one}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + return custom_call; +} + +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, + window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, + reverse_filter, window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, + output, window, dnums); } bool IsReductionToVector(const HloInstruction& reduce) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index d24ed9879d084e96862885efaae2f79a256cd71d..7ad9680bfb4a2ec0d43e2fe86fd138a4a46e2935 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -22,6 +22,9 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they +// don't belong in "ir_emission_utils". + namespace xla { namespace gpu { @@ -30,9 +33,6 @@ constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. bool ImplementedAsGemm(const HloInstruction& hlo); -// Returns true if `hlo` will be implemented as a call to cuDNN convolution. -bool ImplementedAsDnnConvolution(const HloInstruction& hlo); - // A call to cuDNN for batch normalization is represented as CustomCall HLO with // a call target equal to one of these strings. // @@ -58,6 +58,60 @@ extern const char* const kCudnnBatchNormBackwardCallTarget; // sequence of generic HLOs or to a cuDNN CustomCall. bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); +// A call to cuDNN for convolution (forward, backward filter, or backward input) +// is represented as a CustomCall HLO with a call target equal to one of these +// strings. +// +// These CustomCalls have window() and convolution_dimension_numbers() set like +// regular convolution ops. They have the same LHS and RHS operands, plus one +// additional int64 operand, representing which cudnn algorithm to run. This +// operand must be an HLO constant. A value of -1 means that the implementation +// is free to choose the best algorithm it can. +// +// These calls output a tuple (conv_result, scratch_memory), where conv_result +// is the actual result of the convolution, and scratch_memory is temporary +// memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value +// is not well-defined. +// +// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls. +// When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later +// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit +// algorithm for each conv and sets the amount of scratch space needed. +// +// (Representing the scratch memory as an output may seem strange at first, but +// it's quite sensible, from a certain point of view. The scratch buffer is a +// location in memory that the conv can write into, but which it can't legally +// read from, at least until it's written something first. But that's exactly +// the definition of an output buffer.) +extern const char* const kCudnnConvForwardCallTarget; +extern const char* const kCudnnConvBackwardInputCallTarget; +extern const char* const kCudnnConvBackwardFilterCallTarget; + +// Returns true if `hlo` will be implemented as a call to a cuDNN convolution +// routine. +// +// This returns true if `hlo` is a CustomCall HLO with a call target equal to +// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a +// kConvolution opcode. +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); + +// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. +// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If +// you want just the conv result, you'll need to get-tuple-element the value +// returned by this function. +// +// The created cudnn call will use the default cudnn algorithm and no scratch +// space. +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums); + // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 23b72c3f71dacf2be02a0719c07c7e6e88abd00c..a3df67a87344d6ece2ea9047321ad9542c13f8cf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" @@ -615,8 +617,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { // TODO(b/33011107): Support cross replica sum on GPU. - return Unimplemented( - "Cross replica sum not implemented on GPU. See b/33011107."); + return Unimplemented("CrossReplicaSum is not implemented on GPU."); } Status IrEmitter::HandleParameter(HloInstruction* parameter) { @@ -710,11 +711,13 @@ Status IrEmitter::HandleCustomCall(HloInstruction*) { } Status IrEmitter::HandleInfeed(HloInstruction*) { - return Unimplemented("Infeed is not supported on GPU (b/30467474)."); + // TODO(b/30467474): Implement infeed on GPU. + return Unimplemented("Infeed is not supported on GPU."); } Status IrEmitter::HandleOutfeed(HloInstruction*) { - return Unimplemented("Outfeed is not supported on GPU (b/34359662)."); + // TODO(b/34359662): Implement outfeed on GPU. + return Unimplemented("Outfeed is not supported on GPU."); } Status IrEmitter::HandleRng(HloInstruction* random) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 3aa178410f05aef3630a4bd83b9651f6c1aac79b..b0accc08d479258d65a18202122e4c9e90ff78d0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -13,19 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// An XLA HLO graph may contain multiple computations. These computations -// fall into two types, nested and unnested. We translate each nested -// computation (e.g. the computation operand of a Map operator) to a device -// function. For each unnested computation composed of top-level -// HloInstructions, we generate a CUDA kernel for each HloInstruction. -// -// This file declares classes that translate an XLA HLO graph to LLVM IR for -// GPUs. IrEmitterNested emits LLVM IR for nested computations, and -// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR -// for each individual HloInstruction is largely the same between these two -// classes. Therefore, we implement the common logic in the Handle* functions in -// the superclass IrEmitter. - #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ @@ -60,19 +47,28 @@ limitations under the License. namespace xla { namespace gpu { -// This class is the top-level API for the XLA HLO --> LLVM IR compiler. -// It implements the DfsHloVisitor interface and emits an LLVM IR program that -// implements the input HLO graph. +// Abstract base class for translating HLO graphs to LLVM IR for a GPU. +// +// There are two concrete subclasses of IrEmitter: IrEmitterNested and +// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel +// function, whereas in the nested version the whole computation is emitted as +// one *non-kernel* function. +// +// In XLA, kernel functions never call other kernel functions. This means that +// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use +// an HLO computation as a "subroutine" -- e.g. the HLO computation that +// specifies how to reduce two elements -- then the subroutine computation must +// be emitted using IrEmitterNested. // -// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden in -// either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault` -// calls `T::DefaultAction`. +// Fusion nodes are a special case. A fusion node is emitted using +// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is +// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR +// generator generator. See comments on that class. class IrEmitter : public DfsHloVisitorWithDefault { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; - // The following methods implement the DfsHloVisitorWithDefault interface. Status DefaultAction(HloInstruction* hlo) override; Status HandleConstant(HloInstruction* constant) override; Status HandleBitcast(HloInstruction* bitcast) override; @@ -217,202 +213,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::map computation_to_ir_function_; }; -// Emits LLVM IR for unnested computations. Each HloInstruction is translated to -// a separate CUDA kernel. These kernels are inserted into the resultant module -// sorted in reverse postorder of the XLA HLO graph. -class IrEmitterUnnested : public IrEmitter { - public: - IrEmitterUnnested(const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context); - IrEmitterUnnested(const IrEmitterUnnested&) = delete; - IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; - - // Transfers the ownship of thunk_sequence_ out. - std::unique_ptr ConsumeThunkSequence() { - return std::move(thunk_sequence_); - } - - Status DefaultAction(HloInstruction* hlo) override; - - // IrEmitterUnnested handles the following instructions differently from - // IrEmitter. - Status HandleCopy(HloInstruction* copy) override; - Status HandleConditional(HloInstruction* conditional) override; - Status HandleConvolution(HloInstruction* convolution) override; - Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleDot(HloInstruction* dot) override; - Status HandleFft(HloInstruction* fft) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleReduce(HloInstruction* reduce) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status HandleTuple(HloInstruction* tuple) override; - Status HandleWhile(HloInstruction* xla_while) override; - Status HandleInfeed(HloInstruction* xla_infeed) override; - Status HandleRng(HloInstruction* random) override; - Status HandleSelect(HloInstruction* select) override; - - Status EmitTargetElementLoop( - const HloInstruction& hlo, - const llvm_ir::ElementGenerator& body_emitter) override; - - // Same as `EmitTargetElementLoop`, but in given `thunk` rather than - // `LastThunk()`. - Status EmitTargetElementLoopInThunk( - const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, - KernelThunk* thunk); - - private: - // Builds the appropriate thunk for the instruction hlo and returns the owning - // pointer to it. The caller needs to make sure `inst` outlives the lifetime - // of the returned Thunk object. - std::unique_ptr BuildThunk(const HloInstruction* hlo); - - // Builds the prototype of the IR kernel for `inst` and adds it to the module. - llvm::Function* BuildKernelPrototype( - const HloInstruction& inst, - tensorflow::gtl::ArraySlice escaped_hlos); - - // Emits the base pointers for `hlo` and its operands. `io_hlos` will store - // all input/output HLOs among `hlo` and its operands. - llvm::Function* EmitBasePointersForHloAndItsOperands( - const HloInstruction& hlo, std::vector* io_hlos); - - // EmitColumnReduction and EmitRowReduction emit code for column and row - // reduction of a matrix and/or 3D tensor. Row and column reduction have - // different memory access pattern, so for performance their implementations - // are significantly different. - // - // Emits code that reduces a matrix of shape [height x width] to a vector of - // [width]. Other parameters have the same meaning as those of - // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x weight] with "height" - // being the major dimension. - Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Emits code that reduces a 3D tensor of shape [depth x height x width] to a - // vector of shape [height]. Other parameters have the same meaning as those - // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x weight] - // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Figures out whether `reduce` is a row or column reduction, and which - // dimensions to reduce, and calls either `EmitRowReduction` or - // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the - // input array, which is the operand of the Reduce instruction if unfused or - // of the Fusion instruction if fused. `input_gen` and `init_value_gen` - // generate elements of the input and the initial value. Other parameters mean - // the same as for `HandleReduce`. - // - // Prerequisite: `IsReductionToVector(*reduce)` - Status EmitReductionToVector( - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); - - // Emits code to initialize buffer of `inst` in given `thunk`. - Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); - - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The - // caller needs to make sure `inst` outlives the lifetime of the returned - // Thunk object. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst); - - // Returns a ConvolutionThunk that calls DNN to implement `inst`. - std::unique_ptr BuildConvolutionThunk(const HloInstruction* inst); - - // Returns a FftThunk that calls cuFFT to implement `inst`. - std::unique_ptr BuildFftThunk(const HloInstruction* inst); - - // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs - // to make sure `inst` outlives the lifetime of the returned Thunk object. - std::unique_ptr BuildGemmThunk(const HloInstruction* inst); - - // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); - - // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildDeviceToDeviceCopyThunk( - const HloInstruction* inst); - - // Returns an InfeedThunk that performs device-to-device memcpy to implement - // `inst`. - std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); - - // Returns a WhileThunk that invokes thunk sequences for 'condition' and - // 'body' sub-computations of while instruction 'hlo'. - std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); - - // Returns a ForThunk which executes 'loop_limit' invocations of a thunk - // sequence from the 'body' sub-computation of the while instruction 'hlo'. - std::unique_ptr BuildForThunk(const HloInstruction* hlo, - const int64 loop_limit); - - // Returns a ConditionalThunk that executes the thunk sequence for - // 'true_computation' or 'false_computation' depending on the value of the - // predicate in the given conditional instruction. - std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); - - Status Postprocess(HloInstruction* hlo) override; - - // Returns the last generated thunk. - Thunk* LastThunk() const { return thunk_sequence_->back().get(); } - - // The thunk sequence this IrEmitter generates for the input computation. - std::unique_ptr thunk_sequence_; - - // The HloComputation that this IrEmitter emits code for. - const HloComputation* hlo_computation_; -}; - -// Emits LLVM IR for a nested computation to the resultant function. -class IrEmitterNested : public IrEmitter { - public: - // Constructs an LLVM IR emitter for a nested HLO computation. `function` is - // the containing IR function this emitter produces IR to. See - // IrEmitter::IrEmitter for the meanings of other arguments. - IrEmitterNested(const HloModuleConfig& hlo_module_config, - const HloComputation& nested_computation, - IrEmitterContext* ir_emitter_context); - IrEmitterNested(const IrEmitterNested&) = delete; - IrEmitterNested& operator=(const IrEmitterNested&) = delete; - - // Overrides the default empty implementation. Binds the given instruction - // "parameter" with the parameter of the IR function. - Status HandleParameter(HloInstruction* parameter) override; - - llvm::Function* GetEmittedFunction() const { return emitted_function_; } - - Status EmitTargetElementLoop( - const HloInstruction& hlo, - const llvm_ir::ElementGenerator& body_emitter) override; - - private: - llvm::Function* EmitBasePointersForNestedComputation( - const HloComputation& nested_computation, - std::vector* io_hlos); - - llvm::Function* emitted_function_; -}; - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5225ff36ff3a8a1b049479c34aa301de8724f73e..71aada080ae8df70bffce3e1854b5fbd833efd23 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -16,12 +16,13 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" + #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h new file mode 100644 index 0000000000000000000000000000000000000000..ca11cf2c182b0600b931b19d2d7fb3983e36441a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ + +#include "llvm/IR/Function.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" + +namespace xla { +namespace gpu { + +// Emits LLVM IR for a "nested computation" into a non-kernel device function. +// +// This is used to emit code for HloComputations that don't require a separate +// kernel call. For example, IrEmitterNested is used to emit code for a kReduce +// HLO's elementwise reduction computation. Notably, IrEmitterNested is *not* +// used to emit code for fusion nodes -- fusion nodes use FusedIrEmitter, which +// is a different beast altogether. +// +// IrEmitterNested generates a non-kernel function with the following +// parameters: +// +// - N pointers to the buffers of each of the N parameters to the computation, +// - a pointer to the output buffer of the computation, and +// - a pointer to the top-level temp buffer. +// +class IrEmitterNested : public IrEmitter { + public: + // Constructs an LLVM IR emitter for a nested HLO computation. `function` is + // the containing IR function this emitter produces IR to. See + // IrEmitter::IrEmitter for the meanings of other arguments. + IrEmitterNested(const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterNested(const IrEmitterNested&) = delete; + IrEmitterNested& operator=(const IrEmitterNested&) = delete; + + // Overrides the default empty implementation. Binds the given instruction + // "parameter" with the parameter of the IR function. + Status HandleParameter(HloInstruction* parameter) override; + + llvm::Function* GetEmittedFunction() const { return emitted_function_; } + + Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) override; + + private: + llvm::Function* EmitBasePointersForNestedComputation( + const HloComputation& nested_computation, + std::vector* io_hlos); + + llvm::Function* emitted_function_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index fc8783e753d3819ee7a35b2ad660a25eafc42f76..c81dfbf6c2a34aeb6d92ded23b8e264ebec30d54 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" + #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" @@ -39,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" @@ -278,10 +280,6 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { } Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - if (ImplementedAsDnnConvolution(*convolution)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); return IrEmitter::HandleConvolution(convolution); } @@ -380,6 +378,71 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return Status::OK(); } + if (IsCustomCallToDnnConvolution(*custom_call)) { + const auto& assn = ir_emitter_context_->buffer_assignment(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); + auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + auto tuple_result_slice = GetAllocationSlice(*custom_call); + auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + + const HloInstruction* algorithm_inst = custom_call->operand(2); + CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); + int64 algorithm = algorithm_inst->literal().Get({}); + + const auto& target = custom_call->custom_call_target(); + std::unique_ptr thunk; + if (target == kCudnnConvForwardCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kForward, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/conv_result_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, custom_call); + } else if (target == kCudnnConvBackwardInputCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardInput, + /*input_buffer=*/conv_result_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/lhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/lhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, custom_call); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardFilter, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/conv_result_slice, + /*output_buffer=*/rhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, + /*output_shape=*/rhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, custom_call); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + + thunk_sequence_->emplace_back(std::move(thunk)); + return Status::OK(); + } + return IrEmitter::HandleCustomCall(custom_call); } @@ -500,10 +563,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - if (ImplementedAsDnnConvolution(*fusion)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(fusion)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); return IrEmitter::HandleFusion(fusion); } @@ -1599,24 +1658,24 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); - bool all_tuple_elements_have_buffer = std::all_of( - operands.begin(), operands.end(), [this](HloInstruction* tuple_element) { + bool all_tuple_elements_have_buffer = + c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation( tuple_element); }); - // Tuples (especially output tuples) can take too many tuple elements, - // causing the kernel emitted exceeds the parameter space limit - // (b/31336476). As an optimization, if all tuple elements have a buffer, we - // collect their buffer addresses in a host array, and then copy that array - // to the tuple's buffer. + // Tuples (especially tuples that are the final result of a computation) can + // be so huge that if we were to emit a kernel that took each tuple element as + // a parameter, we would exceed the max allowable number of parameters to a + // GPU kernel, b/31336476. As an optimization, if all tuple elements have a + // buffer, we collect their buffer addresses in a host array, and then copy + // that array to the tuple's buffer. // // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back - // to emitting kernels which have access to their buffer addresses in code. + // buffer -- their contents are stored in code. In that case, we fall back to + // emitting kernels which have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; - for (const HloInstruction* tuple_element : operands) { + for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } thunk_sequence_->emplace_back(MakeUnique( @@ -1658,8 +1717,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for select-and-scatter not implemented on GPU. " - "See b/31410564."); + "Dilation for SelectAndScatter not implemented on GPU."); } // kSelectAndScatter is implemented as two kernel launches: the first launch @@ -2012,52 +2070,6 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); } -std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( - const HloInstruction* inst) { - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - if (inst->opcode() == HloOpcode::kConvolution) { - // Forward covolution. - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kForward, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/inst->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - } - - // Backward filter convolution, which takes the input (activations) and the - // gradients, and computes the filter. - CHECK_EQ(HloOpcode::kFusion, inst->opcode()); - switch (inst->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardFilter, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*inst), - /*output_buffer=*/GetAllocationSlice(*rhs), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/inst->shape(), - /*output_shape=*/rhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - case HloInstruction::FusionKind::kConvBackwardInput: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardInput, - /*input_buffer=*/GetAllocationSlice(*inst), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*lhs), - /*input_shape=*/inst->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/lhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - default: - LOG(FATAL) << "Not a convolution-fusion"; - } -} - std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); @@ -2259,6 +2271,8 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + VLOG(3) << bindings_.ToString(); + const Shape& element_shape = hlo.IsMultiOutputFusion() ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h new file mode 100644 index 0000000000000000000000000000000000000000..56ab8208cee6f53afce365baa213fd2f5a6425a0 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -0,0 +1,209 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ + +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" + +namespace xla { +namespace gpu { + +// Emits LLVM IR for an "unnested computation". +// +// An unnested computation is an HloComputation which you run by executing one +// or more kernels for each HloInstruction it contains. Examples of unnested +// computations: +// +// - An HloModule's root computation, +// - The body of an HLO while loop, +// - The true/false computation of an HLO conditional. +// +// Note the opportunity for confusion -- the while loop's computation is nested +// within the root computation, but it's emitted using IrEmitterUnnested! Don't +// think about it too hard. +// +// Examples of things that are not unnested computations: +// +// - The reducer of a kReduce HLO. This is emited using IrEmitterNested. +// - The body of a fusion node. IrEmitterUnenested emits the relevant code +// within a kernel function using FusedIrEmitter. (FusedIrEmitter is not +// really an IrEmitter, but is more an "IR generator generator".) +// +class IrEmitterUnnested : public IrEmitter { + public: + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterUnnested(const IrEmitterUnnested&) = delete; + IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; + + // Transfers the ownship of thunk_sequence_ out. + std::unique_ptr ConsumeThunkSequence() { + return std::move(thunk_sequence_); + } + + Status DefaultAction(HloInstruction* hlo) override; + + // IrEmitterUnnested handles the following instructions differently from + // IrEmitter. + Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleFusion(HloInstruction* fusion) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleInfeed(HloInstruction* xla_infeed) override; + Status HandleRng(HloInstruction* random) override; + Status HandleSelect(HloInstruction* select) override; + + Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) override; + + // Same as `EmitTargetElementLoop`, but in given `thunk` rather than + // `LastThunk()`. + Status EmitTargetElementLoopInThunk( + const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, + KernelThunk* thunk); + + private: + // Builds the appropriate thunk for the instruction hlo and returns the owning + // pointer to it. The caller needs to make sure `inst` outlives the lifetime + // of the returned Thunk object. + std::unique_ptr BuildThunk(const HloInstruction* hlo); + + // Builds the prototype of the IR kernel for `inst` and adds it to the module. + llvm::Function* BuildKernelPrototype( + const HloInstruction& inst, + tensorflow::gtl::ArraySlice escaped_hlos); + + // Emits the base pointers for `hlo` and its operands. `io_hlos` will store + // all input/output HLOs among `hlo` and its operands. + llvm::Function* EmitBasePointersForHloAndItsOperands( + const HloInstruction& hlo, std::vector* io_hlos); + + // EmitColumnReduction and EmitRowReduction emit code for column and row + // reduction of a matrix and/or 3D tensor. Row and column reduction have + // different memory access pattern, so for performance their implementations + // are significantly different. + // + // Emits code that reduces a matrix of shape [height x width] to a vector of + // [width]. Other parameters have the same meaning as those of + // `EmitReductionToVector`. Note that input shape might not be + // [height x width], but can be bitcast to [height x weight] with "height" + // being the major dimension. + Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Emits code that reduces a 3D tensor of shape [depth x height x width] to a + // vector of shape [height]. Other parameters have the same meaning as those + // of `EmitReductionToVector`. Note that input shape might not be + // [depth x height x width], but can be bitcast to [depth x height x weight] + // with "depth" being the most major dimension. + Status EmitRowReduction(int64 depth, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Emits code that reduces a tensor of arbitrary rank to a scalar. + Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Figures out whether `reduce` is a row or column reduction, and which + // dimensions to reduce, and calls either `EmitRowReduction` or + // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the + // input array, which is the operand of the Reduce instruction if unfused or + // of the Fusion instruction if fused. `input_gen` and `init_value_gen` + // generate elements of the input and the initial value. Other parameters mean + // the same as for `HandleReduce`. + // + // Prerequisite: `IsReductionToVector(*reduce)` + Status EmitReductionToVector( + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reducer); + + // Emits code to initialize buffer of `inst` in given `thunk`. + Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); + + // Returns a KernelThunk that invokes the kernel emitted for `inst`. The + // caller needs to make sure `inst` outlives the lifetime of the returned + // Thunk object. + std::unique_ptr BuildKernelThunk(const HloInstruction* inst); + + // Returns a FftThunk that calls cuFFT to implement `inst`. + std::unique_ptr BuildFftThunk(const HloInstruction* inst); + + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs + // to make sure `inst` outlives the lifetime of the returned Thunk object. + std::unique_ptr BuildGemmThunk(const HloInstruction* inst); + + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); + + // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildDeviceToDeviceCopyThunk( + const HloInstruction* inst); + + // Returns an InfeedThunk that performs device-to-device memcpy to implement + // `inst`. + std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); + + // Returns a WhileThunk that invokes thunk sequences for 'condition' and + // 'body' sub-computations of while instruction 'hlo'. + std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); + + // Returns a ForThunk which executes 'loop_limit' invocations of a thunk + // sequence from the 'body' sub-computation of the while instruction 'hlo'. + std::unique_ptr BuildForThunk(const HloInstruction* hlo, + const int64 loop_limit); + + // Returns a ConditionalThunk that executes the thunk sequence for + // 'true_computation' or 'false_computation' depending on the value of the + // predicate in the given conditional instruction. + std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); + + Status Postprocess(HloInstruction* hlo) override; + + // Returns the last generated thunk. + Thunk* LastThunk() const { return thunk_sequence_->back().get(); } + + // The thunk sequence this IrEmitter generates for the input computation. + std::unique_ptr thunk_sequence_; + + // The HloComputation that this IrEmitter emits code for. + const HloComputation* hlo_computation_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 2923a79af0a559b08a2126162130a83801d024f8..25846dc6cd4633c7becb6e62d6bc9585348a6eac 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -27,7 +27,7 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(HloOpcode::kConvolution, conv.opcode()); + CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -47,6 +47,12 @@ HloInstruction* MaybePaddedAndSlicedInput( window_util::HasBaseDilation(conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. + // + // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of + // moving all the padding into an explicit pad op, we should keep as much + // padding inside of cudnn as possible, on the assumption that padding + // within cudnn is basically free, whereas a kPad's cost increases as the + // amount of padding increases. PaddingConfig padding_config = MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { @@ -167,14 +173,17 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { dim->set_window_dilation(1); } + // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract + // out the shape of conv_result. + Shape old_conv_shape = conv->shape().tuple_shapes(0); + VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = HloInstruction::CreateConvolve( - conv->shape(), new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers()); + auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, + new_conv_window, + conv->convolution_dimension_numbers()); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); - TF_CHECK_OK( - conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv))); + TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); return true; } @@ -190,6 +199,8 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { + CHECK_EQ(backward_conv->custom_call_target(), + kCudnnConvBackwardFilterCallTarget); if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; } @@ -202,15 +213,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { @@ -222,11 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // cuDNN convolution (which doesn't support negative padding) to fail. return false; } - // If the backward convolution has uneven padding on the activations, we - // move some padding on the larger end to "internal" padding, so that the - // backward convolution produces larger weight gradients which get sliced - // later. Therefore, the amount of new padding (low or high) is the minimum - // of the amount of old padding low and old padding high. + // Compute the new, even padding for the backward conv operation. int64 new_conv_padding = std::min(padding_low, padding_high); int64 dim = backward_conv_dnums.input_spatial_dimensions(i); input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -237,14 +240,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Since we move some padding from the backward convolution to the kPad, we // need to accordingly reduce the padding amount of the backward convolution // and its inner forward convolution. - IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); + auto* new_dim = new_backward_conv_window.mutable_dimensions(i); + new_dim->set_padding_low(new_conv_padding); + new_dim->set_padding_high(new_conv_padding); } // Create a new backward convolution replacing the old one. @@ -260,19 +258,12 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), input, padding, input_padding_config)); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - padded_input->shape(), output->shape(), new_forward_conv_window, - forward_conv_dnums) - .ConsumeValueOrDie(), - padded_input, output, new_forward_conv_window, forward_conv_dnums)); - - // Fuse the new forward convolution to the new backward convolution. - HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, - new_backward_conv_window, backward_conv_dnums); + // The shape of the backward_conv CustomCall is a tuple (conv_result, + // scratch_buffer). Extract out the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( + backward_conv_shape, padded_input, output, new_backward_conv_window, + backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -289,14 +280,15 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return false; } - HloInstruction* forward_conv = backward_conv->fused_expression_root(); - HloInstruction* reverse_filter = forward_conv->mutable_operand(1); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); + + // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory). + // Get the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + + Shape new_backward_conv_shape = backward_conv_shape; for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); @@ -315,41 +307,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( // where the amount of padding low is larger, we can canonicalize it to // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1)) // [A] = Slice([B A]) - // For consistency, we need to increase the low padding of the inner - // convolution by 1 as well because the input is larger now. if (padding_low > padding_high) { IncreasePaddingLowBy(padding_high - padding_low, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(padding_low - padding_high, - new_forward_conv_window.mutable_dimensions(i)); } else if (padding_low < padding_high) { IncreasePaddingHighBy(padding_low - padding_high, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(padding_high - padding_low, - new_forward_conv_window.mutable_dimensions(i)); } + // Decreasing the padding by X *increases* the size of our output by X. + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + new_backward_conv_shape.set_dimensions( + dim, new_backward_conv_shape.dimensions(dim) + + std::abs(padding_low - padding_high)); } // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_reverse_filter = - computation->AddInstruction(HloInstruction::CreateReverse( - filter->shape(), filter, reverse_filter->dimensions())); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - output->shape(), new_reverse_filter->shape(), - new_forward_conv_window, forward_conv_dnums) - .ConsumeValueOrDie(), - output, new_reverse_filter, new_forward_conv_window, - forward_conv_dnums)); + + HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( + new_backward_conv_shape, output, filter, new_backward_conv_window, + backward_conv_dnums); + + // The CustomCall created above returns a tuple (conv_result, scratch_memory). + // Extract out the two elements. HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_forward_conv, new_reverse_filter}, - HloInstruction::FusionKind::kConvBackwardInput, - new_backward_conv_window, backward_conv_dnums); + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_shape, new_backward_conv_call, 0)); + HloInstruction* new_backward_conv_scratch = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_call->shape().tuple_shapes(1), + new_backward_conv_call, 1)); // Slice the new backward convolution. // @@ -377,22 +366,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( } // Replace the old backward convolution with the slice. - CHECK(ShapeUtil::Compatible( + Shape slice_shape = ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, limit_indices, strides) - .ConsumeValueOrDie(), - backward_conv->shape())); + .ConsumeValueOrDie(); + CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape)) + << ShapeUtil::HumanString(slice_shape) << " vs " + << ShapeUtil::HumanString(backward_conv_shape); - auto slice = - HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv, - start_indices, limit_indices, strides); + HloInstruction* slice = computation->AddInstruction( + HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv, + start_indices, limit_indices, strides)); + HloInstruction* new_tuple = computation->AddInstruction( + HloInstruction::CreateTuple({slice, new_backward_conv_scratch})); VLOG(1) << "Canonicalizing backward input conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " - << slice->ToString(); + << new_tuple->ToString(); - TF_CHECK_OK( - computation->ReplaceWithNewInstruction(backward_conv, std::move(slice))); + TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple)); return true; } @@ -400,18 +392,17 @@ StatusOr PadInsertion::Run(HloModule* module) { bool changed = false; for (HloInstruction* instruction : module->entry_computation()->MakeInstructionPostOrder()) { - if (instruction->opcode() == HloOpcode::kConvolution) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (instruction->opcode() == HloOpcode::kFusion) { - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - changed |= CanonicalizeBackwardFilterConvolution(instruction); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - changed |= CanonicalizeBackwardInputConvolution(instruction); - break; - default: - break; + if (IsCustomCallToDnnConvolution(*instruction)) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } } diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 934e7e1919f08a16daf09ec634e2f9dc0c7cc723..8ed63a854a74fc06c3c389f40fe1f5970885deac 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -42,6 +42,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + // Constructs a loop emitter for a loop that generates on element of each of N + // arrays on each iteration. + // + // This is used in multi-output fusion. target_element_generator should + // produce a struct with N elements, one for each of target_arrays. ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 34e2f7ee206c6a74073d8f4e867e862feb4aff49..cde5877e29f36abc61c5417ce960e2c7699e2749 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -64,10 +64,8 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign) { - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, - &module_sequence); + const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = FindOrDie(module_sequence, entry_computation); @@ -81,9 +79,8 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign) { - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); @@ -199,15 +196,17 @@ Status HeapSimulator::RunComputation( // We can only share with the operand buffer if it is about to be freed; // we must be the last user of the buffer. bool shared = false; - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { - if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && - buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( - operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { - ShareBuffer(buffer, operand_buffer, instruction); - shared = true; - break; + if (options_.may_reuse_operand_buffers) { + for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && + buffer->instruction()->opcode() != HloOpcode::kCopy && + CanShareOperandBufferWithUser( + operand_buffer->instruction(), operand_buffer->index(), + buffer->instruction(), buffer->index(), points_to_analysis)) { + ShareBuffer(buffer, operand_buffer, instruction); + shared = true; + break; + } } } @@ -266,13 +265,12 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), - buffers_to_assign_(buffers_to_assign), + options_(options), module_sequence_(module_sequence) { debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } @@ -280,13 +278,16 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { - // Buffers for constants are ignored, as with BufferAssigner. Also ignore - // buffers that we're not meant to assign. + // Buffers for constants are ignored unless the alloc_constants option is + // set. Also ignore buffers that we're not meant to assign. // // TODO(b/32248867): For consistency, constants should get allocations. - return buffer->instruction()->opcode() == HloOpcode::kConstant || - (buffers_to_assign_ != nullptr && - buffers_to_assign_->count(buffer) == 0); + if (!options_.alloc_constants && + buffer->instruction()->opcode() == HloOpcode::kConstant) { + return true; + } + return options_.buffers_to_assign != nullptr && + options_.buffers_to_assign->count(buffer) == 0; } // Alloc always calls the underlying heap algorithm. @@ -400,8 +401,8 @@ HeapSimulator::Result HeapSimulator::Finish() { } // If we were told to assign specific buffers, make sure we've assigned // exactly that many buffers. - if (buffers_to_assign_ != nullptr) { - CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size()); + if (options_.buffers_to_assign != nullptr) { + CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size()); } } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 88a8698d16132372fc8f4e87eba3b99125aab876..636f19dd39f09721bd82fc4b44785f196f281ad7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -67,6 +67,23 @@ class HeapSimulator { HeapSimulatorTrace debug_trace; }; + // The different options to be passed to the Run() APIs. + struct Options { + Options() + : may_reuse_operand_buffers(true), + alloc_constants(false), + buffers_to_assign(nullptr) {} + + // Whether a buffer about to be Free()-ed, can be recycled for a new born + // one, hence collapsing Free()+Alloc() calls (default true). + bool may_reuse_operand_buffers; + // Whether to issue Alloc() and Free() calls for constants (default false). + bool alloc_constants; + // If 'buffers_to_assign' is provided, only those buffers are assigned + // offsets, otherwise all buffers defined by the instructions are assigned. + const tensorflow::gtl::FlatSet* buffers_to_assign; + }; + // Run the heap simulation with the given algorithm, assuming the given // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid @@ -76,15 +93,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - // If 'buffers_to_assign' is provided, only those buffers are assigned - // offsets, otherwise all buffers defined by the instructions are assigned. static StatusOr Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign = - nullptr); + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -96,8 +110,7 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign = - nullptr); + const Options& options = Options()); private: // If 'module_sequence' is non-null, it is used to find kCall and kWhile @@ -105,8 +118,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -130,7 +142,7 @@ class HeapSimulator { const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; const LogicalBuffer::SizeFunction size_fn_; - const tensorflow::gtl::FlatSet* buffers_to_assign_; + const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; // In addition to Alloc and Free, the heap simulator exposes a concept of diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index a63affa06caf75f1ccab084bd114e39ba7c91a38..5432419e4a2dd2916da32ac6566851bf52fd68ca 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -461,20 +461,6 @@ HloInstruction* HloComputation::CreateFusionInstruction( return fusion_instruction; } -HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums) { - CHECK(HloInstruction::FusionKind::kConvBackwardFilter == fusion_kind || - HloInstruction::FusionKind::kConvBackwardInput == fusion_kind); - HloInstruction* root = instructions_to_fuse.front(); - HloInstruction* fusion_instruction = - AddInstruction(HloInstruction::CreateFusionForBackwardConvolution( - root->shape(), fusion_kind, window, conv_dnums, root)); - FuseInstructionsInto(instructions_to_fuse, fusion_instruction); - return fusion_instruction; -} - StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added, ShapeIndex* index) { @@ -577,8 +563,11 @@ Status HloComputation::ReplaceWithNewInstruction( Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, HloInstruction* new_instruction) { - TF_RET_CHECK(ShapeUtil::Compatible(old_instruction->shape(), - new_instruction->shape())); + TF_RET_CHECK( + ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape())) + << ShapeUtil::HumanString(old_instruction->shape()) << " vs " + << ShapeUtil::HumanString(new_instruction->shape()); + VLOG(10) << "transformed " << old_instruction->ToString() << " to " << new_instruction->ToString(); // Try to add metadata for HLO instructions that are created to replace diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 6436815f910405477ec21a33dec75ef71df08602..061c59abe5e315917161ed737f89de53d71bb1b6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -224,15 +224,6 @@ class HloComputation { tensorflow::gtl::ArraySlice instructions_to_fuse, HloInstruction::FusionKind fusion_kind); - // Creates a fusion instruction that represents a backward convolution. This - // is similar to CreateFusionInstruction but takes window and conv_dnums which - // indicate the window and convolution dimension numbers of the backward - // convolution. - HloInstruction* CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums); - // Create a deep copy of the given instruction and return the instruction // producing the copied result. All instructions performing the copy are added // to the computation. For array-shaped values, this method trivially returns diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index cd54eb74d18d0be714b5b56fc8ae0dfa55ff31a0..9cd5a1e2b71a7aa768e478289e8e4cc13030fcc3 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -469,7 +469,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - return Unimplemented("Custom-call is not implemented for HLO cost analysis."); + // We can't do anything sane with CustomCalls, since we don't know what they + // do, and returning an error status will stop iteration over this + // computation, which is probably also not what we want. So just punt and + // return OK. This will cause all of the properties to be reported as 0, + // which is fine. + current_should_compute_bottleneck_time_ = false; + return Status::OK(); } Status HloCostAnalysis::HandleSort(const HloInstruction* sort) { diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 7feda2b3b040de1f0a14303ce1adcd21c6624c8b..279edd4ba8772a9c576f76f554de8ec68631b953 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -119,9 +119,8 @@ StatusOr HloCSE::Run(HloModule* module) { equivalent_instructions; for (HloInstruction* user : operand->users()) { if (user != instruction && - user->Identical(*instruction, eq_instructions, eq_computations) && - (!is_layout_sensitive_ || - ShapeUtil::Equal(user->shape(), instruction->shape()))) { + user->Identical(*instruction, eq_instructions, eq_computations, + is_layout_sensitive_)) { equivalent_instructions.push_back(user); } } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index c744c8ed81ad991e70a665a9eda469e23054e83d..44fcd36370dcd0cf77601aa1cd2b92810947bd5f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1426,9 +1426,11 @@ void DumpText(const HloModule& module, const string& label, string MaybeDumpHloModule(const HloModule& module, const string& label, const HloExecutionProfile* profile) { - VLOG(2) << "MaybeDumpHloModule called on module " << module.name(); - string graph_url; const DebugOptions& debug_options = module.config().debug_options(); + VLOG(2) << "MaybeDumpHloModule called on module " << module.name() + << " with generate_hlo_graph regex \"" + << debug_options.xla_generate_hlo_graph() << "\""; + string graph_url; if (!debug_options.xla_generate_hlo_graph().empty() && RE2::PartialMatch(module.name(), debug_options.xla_generate_hlo_graph())) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a889c35aeb297bd118c40ced2dd9539957dce67a..277648f07206ecf28479c3f63521732f8f6d8e0f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -763,16 +763,13 @@ HloInstruction::CreateBroadcastSequence( return instruction; } -// We put the fusion kind into the instruction's name for transpose-dot and -// backward-conv fusions, since those fusions are really just describing a type -// of dot/conv rather than generating a novel computation. +// We put the fusion kind into the instruction's name for transpose-dot fusions, +// since those fusions are really just describing a type of dot rather than +// generating a novel computation. static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { switch (fusion_kind) { case HloInstruction::FusionKind::kTransposeDot: return "dot_fusion"; - case HloInstruction::FusionKind::kConvBackwardInput: - case HloInstruction::FusionKind::kConvBackwardFilter: - return "conv_fusion"; default: return "fusion"; } @@ -804,18 +801,6 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { return instruction; } -/* static */ std::unique_ptr -HloInstruction::CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) { - std::unique_ptr fusion = - CreateFusion(shape, fusion_kind, fused_root); - fusion->window_ = MakeUnique(window); - fusion->convolution_dimension_numbers_ = - MakeUnique(conv_dnums); - return fusion; -} - void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -1627,7 +1612,8 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations) const { + eq_computations, + const std::function& eq_shapes) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and @@ -1686,7 +1672,7 @@ bool HloInstruction::IdenticalSlowPath( return parameter_number() == other.parameter_number() && // Check the shape too because `this` and `other` may be in // different HloComputations. - ShapeUtil::Compatible(shape(), other.shape()); + eq_shapes(shape(), other.shape()); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1742,18 +1728,18 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals(window(), other.window()); case HloOpcode::kReshape: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); // Transpose result is determined by the final shape and the permutation. case HloOpcode::kTranspose: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); // Remaining instructions with special values. case HloOpcode::kBitcast: - return ShapeUtil::Equal(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kBroadcast: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); case HloOpcode::kConcatenate: return dimensions() == other.dimensions(); @@ -1767,10 +1753,10 @@ bool HloInstruction::IdenticalSlowPath( slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; case HloOpcode::kDynamicSlice: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dynamic_slice_sizes_ == other.dynamic_slice_sizes_; case HloOpcode::kDynamicUpdateSlice: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); @@ -2318,7 +2304,7 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - auto conv_category = [&] { + if (opcode() == HloOpcode::kConvolution) { string category = "convolution"; if (window_util::HasBaseDilation(window())) { category += " base-dilated"; @@ -2327,10 +2313,6 @@ string HloInstruction::ToCategory() const { category += " window-dilated"; } return category; - }; - - if (opcode() == HloOpcode::kConvolution) { - return conv_category(); } // Give transpose-dot and backwards-conv fusions the categories "dot" and @@ -2348,9 +2330,6 @@ string HloInstruction::ToCategory() const { return "output fusion"; case FusionKind::kTransposeDot: return "dot"; - case FusionKind::kConvBackwardFilter: - case FusionKind::kConvBackwardInput: - return conv_category(); case FusionKind::kCustom: return "custom fusion"; } @@ -3125,10 +3104,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kOutput"; case HloInstruction::FusionKind::kTransposeDot: return "kTransposeDot"; - case HloInstruction::FusionKind::kConvBackwardFilter: - return "kConvBackwardFilter"; - case HloInstruction::FusionKind::kConvBackwardInput: - return "kConvBackwardInput"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3148,12 +3123,6 @@ StatusOr StringToFusionKind( if (kind_name == "kTransposeDot") { return HloInstruction::FusionKind::kTransposeDot; } - if (kind_name == "kConvBackwardFilter") { - return HloInstruction::FusionKind::kConvBackwardFilter; - } - if (kind_name == "kConvBackwardInput") { - return HloInstruction::FusionKind::kConvBackwardInput; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } @@ -3261,7 +3230,13 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { result += "_"; append_dims(rhs_dims, operand(1)->shape()); result += "->"; - append_dims(output_dims, shape()); + + // A convolution can be represented as a kConvolution HLO or as a CustomCall + // that returns a tuple, the first element of which is the result of the + // convolution. + Shape this_shape = + ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); + append_dims(output_dims, this_shape); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5e89dc79bea81e650331e320f7836fdde90b2a53..50931c563a5319ecdf8493ee5507eb2551b34673 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -162,17 +162,14 @@ class HloPrintOptions { class HloInstruction { public: enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kConvBackwardFilter, // Fused into a backward filter convolution. - kConvBackwardInput, // Fused into a backward input convolution. - - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, // Fused into a loop. + kInput, // Op's input is fused into the op itself. + kOutput, // Op's output is fused into the op itself. + // REQUIRES: At least one operand buffer must be able + // to alias the output buffer. + kTransposeDot, // Fused into a dot with transposed operands. + kCustom, // Custom category for backend-specific fusions that + // do not match any of the more specific ones. }; ~HloInstruction(); @@ -466,14 +463,6 @@ class HloInstruction { tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation); - // Creates a fusion instruction that represents backward convolution. This is - // similar to CreateFusion, but with extra arguments indicating the window and - // dimemsion mapping of the backward convolution. - static std::unique_ptr CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, - HloInstruction* fused_root); - // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. static std::unique_ptr CreateCall( @@ -565,27 +554,36 @@ class HloInstruction { } // Returns true if "other" performs the same computation as this instruction. - // Layout of the instructions' output array is not considered. bool Identical( const HloInstruction& other, const std::function& eq_operands = std::equal_to(), const std::function& - eq_computations = std::equal_to()) const { + eq_computations = std::equal_to(), + bool layout_sensitive = true) const { // An instruction is always identical to itself. if (this == &other) { return true; } - // Identical instruction must have the same opcode and identical operands. - // In general, there is no need to check shape because shape is inferred - // from the shape of the operands. + // Identical instruction must have the same opcode, shape, and identical + // operands. if (opcode() != other.opcode()) { return false; } + auto eq_shapes = layout_sensitive + ? [](const Shape& a, + const Shape& b) { return ShapeUtil::Equal(a, b); } + : [](const Shape& a, const Shape& b) { + return ShapeUtil::Compatible(a, b); + }; + if (!eq_shapes(shape(), other.shape())) { + return false; + } if (operands().size() != other.operands().size()) { return false; } + // Use an explicit loop rather than ContainerEquals, because copying around // std::functions may be too expensive in some cases. for (size_t i = 0; i < operands().size(); ++i) { @@ -594,7 +592,7 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations); + return IdenticalSlowPath(other, eq_computations, eq_shapes); } // Returns whether the instruction has a constant operand. @@ -885,8 +883,8 @@ class HloInstruction { // Returns true if this instruction is a fusion instruction that generates // multiple outputs. const bool IsMultiOutputFusion() const { - return (opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple); + return opcode() == HloOpcode::kFusion && + fused_expression_root()->opcode() == HloOpcode::kTuple; } FusionKind fusion_kind() const { @@ -1052,13 +1050,23 @@ class HloInstruction { return *padding_config_; } - // Returns data on the dimension numbers used for a convolution - // operation. + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { CHECK(convolution_dimension_numbers_ != nullptr); return *convolution_dimension_numbers_; } + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + FftType fft_type() const { CHECK_EQ(HloOpcode::kFft, opcode_); return fft_type_; @@ -1233,10 +1241,14 @@ class HloInstruction { class FusionReusesParamElements; // See comments on Identical(). + // eq_shapes() is used to check shapes for equality, and would normally be + // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on + // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations) const; + eq_computations, + const std::function& eq_shapes) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 99d8dd04e5279e0e8a977370beedc4448dc6dc4b..60270b0595dcfca8f1fcea5ab0914428880f35b5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -38,12 +38,16 @@ HloModule::HloModule(const string& name, : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle) {} + entry_computation_handle_(entry_computation_handle), + unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name) - : name_(NameUniquer::GetSanitizedName(name)) {} + : name_(NameUniquer::GetSanitizedName(name)), + unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), config_(config) {} + : name_(NameUniquer::GetSanitizedName(name)), + config_(config), + unique_id_(next_unique_module_id_++) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -564,4 +568,6 @@ uint64 HloModule::RandomNew64() const { return rng_(); } +/* static */ std::atomic HloModule::next_unique_module_id_(0); + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index e377654d024819d00f73f43a70d363bd902dc981..4bfe8d89ce0a285de6d05d4867aaa6b266d78d12 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ +#include #include #include #include @@ -201,6 +202,10 @@ class HloModule { // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) int NumUniqueInstructionIds() const { return next_unique_id_; } + // Returns an id that is unique to this module across all modules created over + // the lifetime of this process. + int unique_id() const { return unique_id_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -227,6 +232,11 @@ class HloModule { NameUniquer computation_name_uniquer_{/*separator=*/"."}; NameUniquer instruction_name_uniquer_{/*separator=*/"."}; int next_unique_id_ = 0; + + // Used to keep track of the next unique module id that should be assigned. + static std::atomic next_unique_module_id_; + // A unique id to label modules with. + int unique_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index cd51fa4e8549daba3e953eece50cb3538f627b89..7f28a804bfec9c2f1bbb5fa08f7dd4e68be14d35 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -188,6 +188,12 @@ TEST_F(HloModuleTest, LargeConstantToString) { module->ToString(HloPrintOptions().set_print_large_constants(true))); } +TEST_F(HloModuleTest, UniqueModuleId) { + auto module_a = CreateNewModule(); + auto module_b = CreateNewModule(); + EXPECT_NE(module_a->unique_id(), module_b->unique_id()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 53bd46a641afcba1b9551895955742e74a9f374b..5120775737bfa32bbb656421216f2b3fbef590ea 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -32,12 +33,28 @@ using ::tensorflow::strings::StrCat; namespace xla { namespace { -void DumpModule(const HloModule& module, - const string& message) { +void DumpModuleGraph(const HloModule& module, const string& message) { hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; XLA_VLOG_LINES(3, module.ToString()); } + +void DumpModuleProto(const HloModule& module, const string& dump_to, + const string& pipeline_name, const string& pass_name) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static auto* const module_id_to_pass_number = + new tensorflow::gtl::FlatMap(); + + tensorflow::mutex_lock lock(mu); + const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; + + const string mod_name = SanitizeFileName(tensorflow::strings::Printf( + "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, + pipeline_name.c_str(), pass_name.c_str())); + + TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), + dump_to, mod_name)); +} } // namespace StatusOr HloPassPipeline::Run(HloModule* module) { @@ -78,6 +95,13 @@ StatusOr HloPassPipeline::Run(HloModule* module) { string message; TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("before running pipeline: ", name()))); + const string xla_dump_per_pass_hlo_proto_to = + module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); + if (!xla_dump_per_pass_hlo_proto_to.empty()) { + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), + "pipeline_start"); + } + for (auto& pass : passes_) { if (disabled_passes.count(pass->name().ToString()) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() @@ -90,17 +114,21 @@ StatusOr HloPassPipeline::Run(HloModule* module) { // Emit label containing: "after foo-pass, before bar-pass". message.clear(); StrAppend(&message, prefix, ", before ", pass->name()); - DumpModule(*module, message); + DumpModuleGraph(*module, message); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("after running pass: ", pass->name()))); + if (!xla_dump_per_pass_hlo_proto_to.empty()) { + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + name().ToString(), pass->name().ToString()); + } changed |= changed_this_pass; prefix.clear(); StrAppend(&prefix, name(), ": after ", pass->name()); } - DumpModule(*module, prefix + ", pipeline end"); + DumpModuleGraph(*module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index e281538848949f8bc232d91692a2227e3550a0a7..41b079eb799d06321a31f7d7ae0630dc8d58c46b 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -47,22 +47,11 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, return tools::Parse(hlo_string, config); } -/*static*/ StatusOr> -HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, - const DebugOptions& debug_options) { - HloProto proto; - - const Status s = - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto); - - if (!s.ok()) { - const Status s2 = - tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto); - if (!s2.ok()) { - return Status(s2.code(), s.error_message() + "\n" + s2.error_message()); - } - } +namespace { +// Creates an HloModule from the given proto. +StatusOr> HloProtoToModule( + const HloProto& proto, const DebugOptions& debug_options) { TF_ASSIGN_OR_RETURN( HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto.hlo_module())); @@ -72,9 +61,29 @@ HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, return std::move(module); } +} // namespace + /*static*/ StatusOr> -HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, +HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename, const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), + filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromTextProtoFile(const std::string& filename, + const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR( + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloTextFile(const std::string& filename, + const DebugOptions& debug_options) { string hlo_string; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), filename, &hlo_string)); @@ -83,19 +92,6 @@ HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, return tools::Parse(hlo_string, config); } -/*static*/ StatusOr> HloRunner::ReadModule( - const std::string& filename, const DebugOptions& debug_options) { - auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options); - if (module.ok()) { - return module; - } - const std::string e = module.status().error_message(); - module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options); - return module.ok() ? std::move(module) - : Status(module.status().code(), - e + "\n" + module.status().error_message()); -} - // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct HloRunner::EigenThreadPoolWrapper { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index d4b221fb52dff64dda264a931df6fd19b86e5260..cbaebc68bee708090b8ccb2eae19b556c4d6d453 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -52,21 +52,15 @@ class HloRunner { const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the - // HloModule. Will try to parse the filename as binary proto, then try as - // text proto if that fails. - static StatusOr> ReadModuleFromHloProtoFile( + // HloModule. + static StatusOr> ReadModuleFromBinaryProtoFile( + const std::string& filename, const DebugOptions& debug_options); + static StatusOr> ReadModuleFromTextProtoFile( const std::string& filename, const DebugOptions& debug_options); // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. - static StatusOr> ReadModuleFromHloTextDumpFile( - const std::string& filename, const DebugOptions& debug_options); - - // Tries to parse the filename specified first as binary proto format, then - // as a textual proto format, then textual IR, then gives up if both fail. - // ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used - // explicitly when you know the format, this if you don't. - static StatusOr> ReadModule( + static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 5413b95cfb6aad464da27c7b4aeaed5011e16393..fce135ef61a7868386b869def1a79167c428d928 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -61,8 +61,8 @@ std::ostream& operator<<(std::ostream& out, BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory) - : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) { + bool mandatory, bool dfs) + : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) { CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok()); } @@ -74,8 +74,8 @@ string BufferLayoutConstraint::ToString() const { OperandLayoutConstraint::OperandLayoutConstraint( const ShapeLayout& shape_layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory) - : LayoutConstraint(mandatory), + int64 operand_no, bool mandatory, bool dfs) + : LayoutConstraint(mandatory, dfs), shape_layout_(shape_layout), instruction_(instruction), operand_no_(operand_no) { @@ -134,7 +134,7 @@ bool LayoutConstraints::OperandBufferForwarded( Status LayoutConstraints::SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory) { + bool mandatory, bool dfs) { VLOG(3) << "SetBufferLayout : " << buffer << " : " << LayoutUtil::HumanString(layout); @@ -171,10 +171,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, if (!overwrite) { iter = buffer_constraints_ .insert(std::make_pair( - &buffer, BufferLayoutConstraint(layout, buffer, mandatory))) + &buffer, + BufferLayoutConstraint(layout, buffer, mandatory, dfs))) .first; } else { - iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true); + iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } added_constraints_.push_back(&iter->second); @@ -188,7 +189,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory) { + int64 operand_no, bool mandatory, + bool dfs) { VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand " << operand_no << " : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -226,12 +228,12 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, if (iter == operand_constraints_.end()) { auto pair = std::make_pair( key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), - instruction, operand_no, mandatory)); + instruction, operand_no, mandatory, dfs)); iter = operand_constraints_.insert(pair).first; } else { iter->second = OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, - operand_no, /*mandatory=*/true); + operand_no, mandatory, dfs); } added_constraints_.push_back(&iter->second); @@ -240,16 +242,17 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, Status LayoutConstraints::SetArrayOperandLayout( const Layout& layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory) { + bool mandatory, bool dfs) { const HloInstruction* operand = instruction->operand(operand_no); TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); - return SetOperandLayout(shape, instruction, operand_no, mandatory); + return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs); } -Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { +Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, + bool dfs) { VLOG(3) << "SetResultLayout : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -267,14 +270,15 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { } result_constraint_.reset( - new ResultLayoutConstraint(ShapeLayout(shape_with_layout))); + new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs)); added_constraints_.push_back(result_constraint_.get()); return Status::OK(); } Status LayoutConstraints::SetInstructionLayout( - const Shape& shape_with_layout, const HloInstruction* instruction) { + const Shape& shape_with_layout, const HloInstruction* instruction, + bool mandatory, bool dfs) { VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -290,8 +294,8 @@ Status LayoutConstraints::SetInstructionLayout( // instruction. return ShapeUtil::ForEachSubshapeWithStatus( shape_with_layout, - [this, instruction](const Shape& subshape, - const ShapeIndex& index) -> Status { + [this, instruction, mandatory](const Shape& subshape, + const ShapeIndex& index) -> Status { // The precondition for this method is that the instruction defines all // buffers in its output. auto buffers = @@ -300,7 +304,7 @@ Status LayoutConstraints::SetInstructionLayout( CHECK_EQ(buffers[0]->instruction(), instruction); if (ShapeUtil::IsArray(subshape)) { - return SetBufferLayout(subshape.layout(), *buffers[0]); + return SetBufferLayout(subshape.layout(), *buffers[0], mandatory); } else { return Status::OK(); } @@ -394,8 +398,7 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - instruction->outfeed_shape(), instruction, 0, - /*mandatory=*/true)); + instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in // ComputationLayout. @@ -434,8 +437,8 @@ Status LayoutAssignment::AddMandatoryConstraints( {0})); Shape new_shape = channel_constraints->LayoutShapeForChannel( recv_buffer_shape, instruction->channel_id()); - TF_RETURN_IF_ERROR(constraints->SetBufferLayout( - new_shape.layout(), *buffer, /*mandatory=*/true)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(new_shape.layout(), *buffer)); } } } @@ -457,7 +460,7 @@ Status LayoutAssignment::AddMandatoryConstraints( for (int64 i = 0; i < instruction->operand_count(); ++i) { TF_RETURN_IF_ERROR(constraints->SetOperandLayout( called_computation_layout.parameter_layout(i).shape(), instruction, - i, /*mandatory=*/true)); + i)); } } else if (instruction->opcode() == HloOpcode::kWhile) { // Layout of input and output of kWhile instruction must be equal and must @@ -508,8 +511,7 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - body_layout.result_shape(), instruction, 0, - /*mandatory=*/true)); + body_layout.result_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { if (!CustomCallRequiresMajorFirstLayout(instruction)) { continue; @@ -533,7 +535,7 @@ Status LayoutAssignment::AddMandatoryConstraints( operand_shape.element_type(), AsInt64Slice(operand_shape.dimensions())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i, /*mandatory=*/true)); + row_major_operand_shape, instruction, i)); } } } @@ -907,7 +909,11 @@ Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) { auto add_new_constraints_to_worklist = [constraints, &worklist]() { // Add constraints to the front of the deque for DFS ordering. for (auto* constraint : constraints->ConsumeAddedConstraints()) { - worklist.push_front(constraint); + if (constraint->dfs()) { + worklist.push_front(constraint); + } else { + worklist.push_back(constraint); + } } }; add_new_constraints_to_worklist(); @@ -1390,7 +1396,7 @@ Status LayoutAssignment::RunOnComputation( // Add any backend-specific constraints. TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints)); - // Propagates layouts from an HLO to its neighbors. + // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); // While any unconstrained buffers remain, pick an arbitrary buffer, give it a @@ -1455,7 +1461,12 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // Assign layouts to computations in an order such that a callee computation // is handled before its caller computation. This ensures that the layout of // all callers of a computation will agree. + std::list computation_post_order = + module->MakeComputationPostOrder(); for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; + } // Clear existing layouts of the instructions. All layouts must be assigned // by the LayoutAssignment pass, except for those on infeeds, parameters, // and the computation result. The latter two are specified in @@ -1467,13 +1478,10 @@ StatusOr LayoutAssignment::Run(HloModule* module) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } - if (computation == module->entry_computation()) { TF_RETURN_IF_ERROR(RunOnComputation( *entry_computation_layout_, *points_to_analysis, module->entry_computation(), channel_layout_constraints_)); - } else if (computation->IsFusionComputation()) { - continue; } else { ComputationLayout computation_layout(computation->ComputeProgramShape()); // Setting all embedded computations to the default layout is potentially diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 6bfae2998609c0482b91368f1891ce1e8e43fa23..29018584487cabfd740d7914625c2a50f552d6ff 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -46,7 +46,8 @@ namespace xla { // gathered together in LayoutConstraints object. class LayoutConstraint { public: - LayoutConstraint(bool mandatory) : mandatory_(mandatory) {} + LayoutConstraint(bool mandatory, bool dfs) + : mandatory_(mandatory), dfs_(dfs) {} virtual ~LayoutConstraint() = default; virtual string ToString() const = 0; @@ -54,8 +55,12 @@ class LayoutConstraint { // True if this constraint cannot be overwritten by a different constraint. bool mandatory() const { return mandatory_; } + // When true, propagate in DFS. When false, constraint will propagate in BFS. + bool dfs() const { return dfs_; } + private: bool mandatory_; + bool dfs_; }; std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); @@ -65,7 +70,7 @@ std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); class BufferLayoutConstraint : public LayoutConstraint { public: BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory); + bool mandatory, bool dfs); const LogicalBuffer& buffer() const { return *buffer_; } const Layout& layout() const { return layout_; } @@ -86,7 +91,7 @@ class OperandLayoutConstraint : public LayoutConstraint { public: OperandLayoutConstraint(const ShapeLayout& shape_layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory); + bool mandatory, bool dfs); const ShapeLayout& shape_layout() const { return shape_layout_; } const HloInstruction* instruction() const { return instruction_; } @@ -106,8 +111,10 @@ class OperandLayoutConstraint : public LayoutConstraint { // Constraint on the layout of the result of the entry computation. class ResultLayoutConstraint : public LayoutConstraint { public: - explicit ResultLayoutConstraint(const ShapeLayout& shape_layout) - : LayoutConstraint(/*mandatory=*/true), shape_layout_(shape_layout) {} + explicit ResultLayoutConstraint(const ShapeLayout& shape_layout, + bool dfs = false) + : LayoutConstraint(/*mandatory=*/true, dfs), + shape_layout_(shape_layout) {} const ShapeLayout& shape_layout() const { return shape_layout_; } string ToString() const override; @@ -157,23 +164,25 @@ class LayoutConstraints { // operand of the instruction, or the layout of the result of the computation, // respectively. Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory = true); + bool mandatory = true, bool dfs = true); Status SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory = true); - Status SetResultLayout(const Shape& shape_with_layout); + bool mandatory = true, bool dfs = true); + Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true); // Convenience wrapper around SetOperandLayout for setting the layout of a // operand using a Layout object. The operand must be array-shaped. Status SetArrayOperandLayout(const Layout& layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory = true); + int64 operand_no, bool mandatory = true, + bool dfs = true); // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers // created by the instruction to the layouts in the given shape. The // instruction must define every logical buffer in its output. Status SetInstructionLayout(const Shape& shape_with_layout, - const HloInstruction* instruction); + const HloInstruction* instruction, + bool mandatory = true, bool dfs = true); // Returns true if any buffer in the given operand is forwarded to the output // of the given instruction. For example, the Tuple instruction forwards the diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 9ad7cd82cb8ca862fd7acec3dfb12c9fd61f6e27..b3b6026ef17daa184c0a015fdea618597ef068b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -32,8 +32,23 @@ limitations under the License. namespace xla { -// Unlike IrEmitter, this creates host functions which emit IR to generate the -// output element at the given index. It is used to generate fused operations. +// FusedIrEmitter is used to generate code for fusion nodes. +// +// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM +// Module, FusedIrEmitter is better understood as "IR generator generator". +// FusedIrEmitter recursively creates a generator (a host function) which the +// compiler can invoke at a later time. Invoking the generator emits LLVM IR +// that, when run, produces the value at a particular index of the output. +// +// After building this generator, the compiler creates a loop (or its moral +// equivalent, e.g. a GPU kernel) and calls the generator from within the loop. +// This generates code that produces each element of the output. +// +// This class handles both vanilla fusion and multi-output fusion. In the MOF +// case, the fusion node ends with a kTuple instruction, and the generator +// created produces an LLVM struct with N elements, one for each element of the +// arrays in the tuple. It follows that the arrays in the tuple must have the +// same length. class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using Generator = llvm_ir::ElementGenerator; diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index a5f7c850c33757fe8d48567ade35544d81224e46..b6b918ec78a27b90325f72eea14b97f9aee43c54 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -51,37 +51,40 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, shape_(target_array.GetShape()), ir_builder_(ir_builder) {} +static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( + const ElementGenerator& target_element_generator, + const std::vector& target_arrays, llvm::IRBuilder<>* ir_builder) { + return [=](const llvm_ir::IrArray::Index array_index) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); + CHECK(target_element->getType()->isStructTy()) + << "This BodyEmitter is for multi-output fusion, but target element " + "generator does not produce values of struct type."; + CHECK_EQ(target_element->getType()->getStructNumElements(), + target_arrays.size()); + + for (int64 i = 0; i < target_arrays.size(); ++i) { + target_arrays[i].EmitWriteArrayElement( + array_index, ir_builder->CreateExtractValue(target_element, i), + ir_builder); + } + return Status::OK(); + }; +} + LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { - // Convert target_element_generator to a BodyEmitter. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - target_element_generator(array_index)); - if (target_arrays.size() == 1) { - target_arrays[0].EmitWriteArrayElement(array_index, target_element, - ir_builder); - return tensorflow::Status::OK(); - } - - for (int64 i = 0; i < target_arrays.size(); ++i) { - target_arrays[i].EmitWriteArrayElement( - array_index, ir_builder_->CreateExtractValue(target_element, i), - ir_builder); - } - return tensorflow::Status::OK(); - }), + : body_emitter_(MakeBodyEmitterForMultiOutputFusion( + target_element_generator, + std::vector(target_arrays.begin(), target_arrays.end()), + ir_builder)), + shape_(target_arrays[0].GetShape()), ir_builder_(ir_builder) { - if (target_arrays.size() > 1) { - // The sanity check for multiple outputs. - shape_ = target_arrays[0].GetShape(); - for (int64 i = 1; i < target_arrays.size(); ++i) { - const Shape& element_shape = target_arrays[i].GetShape(); - CHECK(ShapeUtil::SameDimensions(shape_, element_shape)); - } - } else { - shape_ = target_arrays[0].GetShape(); + // Sanity check: In multi-output fusion, all shapes produced must have the + // same dimensions. + for (const IrArray& array : target_arrays) { + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 1ef1dc246442041698d96f6aff48794c8788f1d1..0fc528439a0d5bf8382dfcf2d8b3051f8900bf1d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -47,10 +47,16 @@ class LoopEmitter { // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder); - // Same as previous method except emits multiple targets in an array. + + // Constructs a LoopEmitter that emits one element into each of N separate + // arrays on each iteration of the loop. + // + // This is used for multi-output fusion. target_element_generator must + // produce an LLVM struct with N elements. LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, llvm::IRBuilder<>* ir_builder); + LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; virtual ~LoopEmitter() = default; diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index bb9fd447d98b8bcb59af139af2ec0d4fa073844e..07f989d4faea199e812e54d2ae74d3ff9e7fa19a 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -71,8 +72,7 @@ LocalService::LocalService(const ServiceOptions& options, StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal, - DeviceMemoryAllocator* device_allocator) { + const ExecutableBuildOptions& build_options) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(computation)); VersionedComputationHandle versioned_handle = @@ -113,14 +113,19 @@ StatusOr> LocalService::CompileExecutable( ShapeUtil::HumanString(argument_shape).c_str()); } } - if (result_layout != nullptr) { - TF_RETURN_IF_ERROR( - ValidateResultShapeWithLayout(*result_layout, program_shape->result())); + if (build_options.result_layout() != nullptr) { + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( + *build_options.result_layout(), program_shape->result())); } ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (result_layout != nullptr) { - *execution_options.mutable_shape_with_output_layout() = *result_layout; + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); } else { *execution_options.mutable_shape_with_output_layout() = program_shape->result(); @@ -132,11 +137,13 @@ StatusOr> LocalService::CompileExecutable( CreateModuleConfig(*program_shape, argument_layouts, &execution_options, *user_computation)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - execute_backend_->stream_executor(device_ordinal)); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(build_options.device_ordinal())); return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor, device_allocator); + execute_backend_.get(), executor, + build_options.device_allocator()); } StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 16c71b25c4b7456f8f1a2bf873e0f25ae5af2c32..15e120685e1be9190d49fdaf5ed6706bdf991a6c 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -47,8 +48,7 @@ class LocalService : public Service { StatusOr> CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal, - DeviceMemoryAllocator* device_allocator); + const ExecutableBuildOptions& options); // Returns the device ordinal that corresponds to the given replica number. // diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index fea69563456edf63b604bcfb187fc6b745f12c77..98dfc89867ab33788c4cc837a66d6751a1ef2507 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1453,9 +1453,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddInfeedInstruction(arg->infeed_request()); break; case OpRequest::kOutfeedRequest: - TF_RETURN_IF_ERROR( - computation->AddOutfeedInstruction(arg->outfeed_request())); - return tensorflow::Status::OK(); + handle_status = + computation->AddOutfeedInstruction(arg->outfeed_request()); + break; case OpRequest::kMapRequest: { TF_ASSIGN_OR_RETURN( UserComputation * to_apply, @@ -1619,14 +1619,14 @@ StatusOr> Service::Replicas( } Status Service::MaybeDumpHloModule(const HloModule& module) const { - const string xla_dump_prepass_hlo_proto_to = - module.config().debug_options().xla_dump_prepass_hlo_proto_to(); - if (xla_dump_prepass_hlo_proto_to.empty()) { + const string xla_dump_unoptimized_hlo_proto_to = + module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); + if (xla_dump_unoptimized_hlo_proto_to.empty()) { return Status::OK(); } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_prepass_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, module.name()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 2ea6507900e712200ce43e9b63577a4967381fdf..ef9c80b0431cb7dbf813732724d24d113dfbc4ab 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1185,7 +1185,7 @@ StatusOr UserComputation::AddInfeedInstruction( return handle; } -Status UserComputation::AddOutfeedInstruction( +StatusOr UserComputation::AddOutfeedInstruction( const OutfeedRequest& outfeed_request) { tensorflow::mutex_lock lock(mutex_); @@ -1197,8 +1197,6 @@ Status UserComputation::AddOutfeedInstruction( // Verify that operand is valid. TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. ComputationDataHandle handle = CreateComputationDataHandle(); OperationRequest& request = (*session_computation_.mutable_requests())[handle.handle()]; @@ -1209,7 +1207,7 @@ Status UserComputation::AddOutfeedInstruction( VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() << "), data handle " << handle.handle() << ": " << outfeed_request.ShortDebugString(); - return Status::OK(); + return handle; } StatusOr UserComputation::AddCallInstruction( diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 4f92e58877a1d06728fdd250744ca2ce7b57d9ad..54bb24d6d7fe7aa8cc7c684795e40464e4eb6614 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -146,7 +146,8 @@ class UserComputation { const InfeedRequest& infeed_request); // Enqueues an outfeed instruction onto this user computation. - Status AddOutfeedInstruction(const OutfeedRequest& outfeed_request); + StatusOr AddOutfeedInstruction( + const OutfeedRequest& outfeed_request); // Enqueues a call instruction onto this user computation. StatusOr AddCallInstruction( diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ca02115863e6906ef709ba63259024877e0dcef4..2fa163953f638c0038e9f6bb11ce2a3742e0558c 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -67,7 +67,8 @@ TEST_F(UserComputationTest, SimpleComputation) { *outfeed_request.mutable_operand() = constant_handle; *outfeed_request.mutable_shape() = kVectorShape; outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, + computation.AddOutfeedInstruction(outfeed_request)); auto hlo_resolver = [](const VersionedComputationHandle& handle) { return nullptr; diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 56fc21d019bb823f8f4631420a15fd607ef46a9a..b788631fa3e7ae45f4d72e8d6fb006c2332d2d1e 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1879,20 +1879,73 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { auto min_scalar = builder.ConstantR0(0.0f); auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto arg_scalar = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto max_scalar = builder.ConstantR0(3.0f); auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. auto clamp = builder.Add( builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_scalar, max_vector), - builder.Clamp(min_scalar, arg_scalar, max_vector))); + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); - ComputeAndCompareR1(&builder, {8.0f, 4.5f, 2.0f, 6.5f, 15.0f}, {}, + ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { + ComputationBuilder builder(client_, TestName()); + auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); + auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); + auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); + auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + + ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0(0); + auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); + auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); + auto max_scalar = builder.ConstantR0(3); + auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + // Perform clamp with broadcasted scalar and vector. + auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); + + ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { + ComputationBuilder builder(client_, TestName()); + auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); + auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); + auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); + auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + + ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0(0); + auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); + auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); + auto max_scalar = builder.ConstantR0(3); + auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + // Perform clamp with broadcasted scalar and vector. + auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); + + ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 7c1a993b478a0e0878e85c0e4192da053e33619f..9f5806c5e16c30cf198027cffab5f78c315cb957 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -230,7 +230,7 @@ template const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = - HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); if (!module_or_status.ok()) { return ::testing::AssertionFailure() << "failed reading hlo module from file"; @@ -258,7 +258,7 @@ template const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = - HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); if (!module_or_status.ok()) { return ::testing::AssertionFailure() << "failed reading hlo module from file"; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 39c07297d69b6683f0fbe75fdd2186effda5043c..474d2547aeba1ec478eb3aa0cacfc04d9dee142e 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -376,6 +376,10 @@ class NearComparator { abs_expected_miscompare_sum_ = 0.0; max_rel_err_ = 0.0; max_abs_err_ = 0.0; + first_linear_index_ = -1; + last_linear_index_ = -1; + max_rel_linear_index_ = -1; + max_abs_linear_index_ = -1; miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED)); miscompares_.PopulateWithValue(false); multi_index_.resize(expected.shape().dimensions_size(), 0); @@ -482,11 +486,11 @@ class NearComparator { const float rel_err = abs_diff / std::abs(expected); abs_diff_sum_ += abs_diff; abs_expected_sum_ += std::abs(expected); - if (rel_err > max_rel_err_) { + if (rel_err > max_rel_err_ || std::isnan(rel_err)) { max_rel_err_ = rel_err; max_rel_linear_index_ = linear_index; } - if (abs_diff > max_abs_err_) { + if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) { max_abs_err_ = abs_diff; max_abs_linear_index_ = linear_index; } diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index e477784557a3b9340cff644a3695485389d8cc22..3a421f8458268a14dcdd84889bcae4990c095ea4 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -97,5 +97,29 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } } +TEST(LiteralTestUtilTest, NearComparatorR1) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtilTest, NearComparatorR1Nan) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto b = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtil, NearComparatorDifferentLengths) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 3fd83a4c3b104831f03366339fb7b8b5d816a3f7..8cef8dd34dc7b16b1e58ded67d6b6a4ba79f20db 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -33,6 +33,14 @@ limitations under the License. namespace xla { namespace { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. +static std::array use_bfloat16_params{false}; +#endif + class PadTest : public ClientLibraryTestBase { protected: PadTest() { @@ -61,8 +69,22 @@ class PadTest : public ClientLibraryTestBase { PaddingConfig r4_padding_on_dim0_dim1_; }; +class PadTestFloat : public PadTest, + public ::testing::WithParamInterface { + protected: + PadTestFloat() { set_use_bfloat16(GetParam()); } + + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-3, 1e-3); + } else { + return ErrorSpec(1e-5, 1e-5); + } + } +}; + // Tests a Pad() with a zero-element input and output. -XLA_TEST_F(PadTest, Pad1DS0ToS0Array) { +XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 0, high: 0, interior: 0}. PaddingConfig padding_config; @@ -71,12 +93,13 @@ XLA_TEST_F(PadTest, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); - ComputeAndCompareR1(&b, {}, {}, ErrorSpec(0.0001)); + b.Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); + ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } // Tests a Pad() with a zero-element input but a non-zero-element output. -XLA_TEST_F(PadTest, Pad1DS0ToS5Array) { +XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; @@ -85,12 +108,13 @@ XLA_TEST_F(PadTest, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); + b.Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, - ErrorSpec(0.0001)); + DefaultErrorSpec()); } -XLA_TEST_F(PadTest, Pad1DS3Array) { +XLA_TEST_P(PadTestFloat, Pad1DS3Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; @@ -99,21 +123,21 @@ XLA_TEST_F(PadTest, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - b.Pad(b.ConstantR1({1, 2, 3}), b.ConstantR0(0.1), - padding_config); + b.Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); - ComputeAndCompareR1(&b, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, Pad4D_2x0x3x2_FloatArray) { +XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { ComputationBuilder b(client_, TestName()); - b.Pad(b.ConstantR4FromArray4D(Array4D(2, 0, 3, 2)), - b.ConstantR0(1.5), r4_padding_on_dim0_dim1_); + b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), + AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, - ErrorSpec(0.0001)); + DefaultErrorSpec()); } -TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { +TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { ComputationBuilder b(client_, TestName()); auto input = MakeUnique>(1, 1, 3, 2); Array2D input_xy({ @@ -123,7 +147,7 @@ TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(1.5), + b.Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -134,15 +158,15 @@ TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { (*expected)(1, 0, 1, 1) = 4.0f; (*expected)(1, 0, 2, 0) = 5.0f; (*expected)(1, 0, 2, 1) = 6.0f; - ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR4(&b, *expected, {}, DefaultErrorSpec()); } -TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { ComputationBuilder b(client_, TestName()); const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - b.Pad(b.ConstantR4FromArray4D(input), b.ConstantR0(pad_value), + b.Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); @@ -156,7 +180,7 @@ TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) { ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); } -TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { +TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { ComputationBuilder b(client_, TestName()); PaddingConfig padding_config; @@ -184,7 +208,8 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + b.Pad(AddParam(*input, &b), + AddParam(*Literal::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -197,7 +222,7 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { ComputeAndCompareR4(&b, expected_array, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { +XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { ComputationBuilder b(client_, TestName()); PaddingConfig padding_config; @@ -229,7 +254,8 @@ XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + b.Pad(AddParam(*input, &b), + AddParam(*Literal::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -249,7 +275,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { }); input->FillWithYX(input_xy); - b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(35), + b.Pad(AddParam(*input, &b), b.ConstantR0(35), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -277,8 +303,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { auto ones = MakeUnique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); - b.Select(padded, b.ConstantR4FromArray4D(*ones), - b.ConstantR4FromArray4D(*zeros)); + b.Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(0); @@ -291,10 +316,12 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { ComputeAndCompareR4(&b, *expected, {}); } -XLA_TEST_F(PadTest, Large2DPad) { +XLA_TEST_P(PadTestFloat, Large2DPad) { ComputationBuilder b(client_, TestName()); - auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {4, 4}), "input"); + auto ones = MakeUnique>(4, 4); + ones->Fill(1.0f); + auto input = AddParam(*ones, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -302,25 +329,22 @@ XLA_TEST_F(PadTest, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - auto padded = b.Pad(input, b.ConstantR0(0.0f), padding_config); - - auto ones = MakeUnique>(4, 4); - ones->Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*ones); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); - ComputeAndCompareR2(&b, *expected, {input_data.get()}); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, AllTypes2DPad) { +XLA_TEST_P(PadTestFloat, AllTypes2DPad) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(0.0f); + auto input = AddParam(*operand, &b); + PaddingConfig padding_config = MakeNoPaddingConfig(2); padding_config.mutable_dimensions(0)->set_edge_padding_low(7); padding_config.mutable_dimensions(0)->set_edge_padding_high(5); @@ -328,20 +352,14 @@ XLA_TEST_F(PadTest, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - auto padded = b.Pad(input, b.ConstantR0(3.14f), padding_config); - - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(0.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec{0.0001}); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, High2DPad) { +XLA_TEST_P(PadTestFloat, High2DPad) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 129; @@ -349,8 +367,9 @@ XLA_TEST_F(PadTest, High2DPad) { constexpr int64 low_padding = 0; int64 high_padding[2] = {5, 7}; constexpr int64 interior_padding = 0; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low(low_padding); @@ -359,20 +378,15 @@ XLA_TEST_F(PadTest, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, NegativePadding2D) { +XLA_TEST_P(PadTestFloat, NegativePadding2D) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 129; @@ -380,8 +394,9 @@ XLA_TEST_F(PadTest, NegativePadding2D) { int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {-3, 4}; constexpr int64 interior_padding = 0; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -391,20 +406,15 @@ XLA_TEST_F(PadTest, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { +XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 8; @@ -412,8 +422,9 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { int64 low_padding[2] = {4, -1}; int64 high_padding[2] = {-2, -4}; int64 interior_padding[2] = {1, 2}; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -423,44 +434,40 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } // Regression test for b/31827337. -XLA_TEST_F(PadTest, ReducePad) { +XLA_TEST_P(PadTestFloat, ReducePad) { ComputationBuilder b(client_, TestName()); - auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "input"); + auto ones = MakeUnique>(2, 2, 2, 2); + ones->Fill(1.0); + auto input = AddParam(*ones, &b); - Computation add_f32 = CreateScalarAddComputation(F32, &b); - auto reduce = b.Reduce(input, b.ConstantR0(0.0), add_f32, {0}); + Computation add = CreateScalarAddComputation(FloatType(), &b); + auto reduce = + b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - auto pad = b.Pad(reduce, b.ConstantR0(0.0), padding_config); - - auto ones = MakeUnique>(2, 2, 2, 2); - ones->Fill(1.0); - auto input_literal = Literal::CreateR4FromArray4D(*ones); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), + padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, {{2.0, 2.0}, {2.0, 2.0}}, {{0.0, 0.0}, {0.0, 0.0}}}); - ComputeAndCompareR3(&b, expected, {input_data.get()}); + ComputeAndCompareR3(&b, expected, {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(PadTestFloatInstantiation, PadTestFloat, + ::testing::ValuesIn(use_bfloat16_params)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index a766fa2db0e193c52171490981855843ab3ee158..50d7b5074d201d2292cf90224ef4cd37efdbb8d3 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -494,6 +494,26 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { ErrorSpec(0.01, 1e-4)); } +// Test that algebraic simplifier does not incorrectly fold a transpose into a +// reduction operation. +XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); + ComputationDataHandle input = builder.Parameter(0, input_shape, "input"); + ComputationDataHandle zero = builder.ConstantR0(0.0); + ComputationDataHandle transpose = + builder.Transpose(input, /*permutation=*/{1, 0, 2}); + ComputationDataHandle reduce = + builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + MakeFakeLiteral(input_shape)); + + ComputeAndCompare(&builder, reduce, {std::move(*input_data)}, + ErrorSpec(0.01, 1e-4)); +} + XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { const int64 rows = 111, cols = 50; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index debf2d2d317fe64ca1ef86cb1f2978e76af1b55d..4da6ee91607941b395b00befc98a10e7c17746ed 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -737,7 +737,61 @@ XLA_TEST_F(ScalarComputationsTest, PowScalar) { ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(2), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(-5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, -1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(2), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(0), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(5.0f), // The operand to be clamped. @@ -746,7 +800,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { ComputeAndCompareR0(&builder, 3.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(2.5f), // The operand to be clamped. @@ -755,7 +809,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { ComputeAndCompareR0(&builder, 2.5, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarLow) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(-5.0f), // The operand to be clamped. @@ -852,5 +906,12 @@ XLA_TEST_F(ScalarComputationsTest, SqrtF320) { ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } +XLA_TEST_F(ScalarComputationsTest, RoundScalar) { + ComputationBuilder builder(client_, TestName()); + builder.Round(builder.ConstantR0(1.4f)); + + ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 4ad356d045bbbd106c3f1a3271a684554edd48a0..b82f1c81c84b487c1661af5267b9123da97bb107 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -85,10 +85,12 @@ void RealMain(tensorflow::gtl::ArraySlice args) { for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } + ExecutableBuildOptions build_options; + build_options.set_device_ordinal(0); + build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*device_allocator=*/nullptr); + local_service->CompileExecutable(computation.handle(), layouts, + build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 5ebb75a31c1280882aa56b30ad4d568651515492..05c0fdf97d27c09eb2bbb0f265b5b2a5982ca7b1 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -60,10 +60,13 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } + + ExecutableBuildOptions build_options; + build_options.set_device_ordinal(0); + build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*device_allocator=*/nullptr); + local_service->CompileExecutable(computation.handle(), layouts, + build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc index 4e02e17db65c0a4220672733be8319e1a0cc4f0f..8460ae3e4991ee091af72d2553a8491f627c722e 100644 --- a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -19,7 +19,7 @@ limitations under the License. // // Reads one serilized Hlo module, convert it into JSON format and dump into // some output directory. some_binaray_proto is obtained by serializing Hlo -// module to disk using --xla_dump_hlo_proto_to debug optoin. +// module to disk using --xla_dump_optimized_hlo_proto_to debug option. #include #include diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index b0209050350e6d9a70ab14c6f9ed6577809f7801..1f0c626bbb2d64ef4e67c9ec51485ae96ae73d04 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -339,7 +339,7 @@ std::vector> CommonFactors( string SanitizeFileName(string file_name) { for (char& c : file_name) { - if (c == '/' || c == '\\' || c == '[' || c == ']') { + if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') { c = '_'; } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 2da9bb21b7ad20ffbdcce600efda946155061cb5..08df5b12b3a53a138f56705531baa3333b23c5d8 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -217,6 +217,24 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); // Passed-varargs variant of the InvalidArgument factory above. Status InvalidArgumentV(const char* format, va_list args); +template +Status UnimplementedStrCat(Args&&... concat) { + return Unimplemented( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + +template +Status InternalErrorStrCat(Args&&... concat) { + return InternalError( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + +template +Status ResourceExhaustedStrCat(Args&&... concat) { + return ResourceExhausted( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + // Splits the lines of the original, replaces leading whitespace with the prefix // given by "indentation", and returns the string joined by newlines again. As a // side effect, any additional trailing whitespace is removed. diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index e1ed08c8480fa73e9c5ff914bb9f5e38f1ce96e9..56162ab44e2e0e3e4478fe631888f243332dc1d8 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -82,8 +82,9 @@ message DebugOptions { // Dump all HLO modules as text into the provided directory path. string xla_generate_hlo_text_to = 7; - // Dump compilation artifacts in binary proto into this directory. - string xla_dump_hlo_proto_to = 8; + // Dump Hlo after all hlo passes are executed as proto binary into this + // directory. + string xla_dump_optimized_hlo_proto_to = 8; // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; @@ -179,9 +180,13 @@ message DebugOptions { // ops. bool xla_gpu_use_cudnn_batchnorm = 94; - // Dump compilation artifacts, before hlo passes are executed, in binary proto - // into this directory. - string xla_dump_prepass_hlo_proto_to = 95; + // Dump HLO before any hlo passes are executed as proto binary into this + // directory. + string xla_dump_unoptimized_hlo_proto_to = 95; + + // Dump HLO after each pass as an HloProto in binary file format into this + // directory. + string xla_dump_per_pass_hlo_proto_to = 96; // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 5ac5955626a83439f5a73e961e2ce056739956fe..1c497c666bd73e0ae39b750307c25e3b28bfaf2d 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -24,6 +24,7 @@ py_library( "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/cloud:cloud_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_ops_py", "//tensorflow/contrib/compiler:compiler_py", @@ -77,6 +78,7 @@ py_library( "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", + "//tensorflow/contrib/py2tf", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 119fa9cd2c378d2ba2383ea8b0e09e1b6083d84e..707853b59befc2625145ad96952fbf9f66d62b43 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/stat_summarizer.h" -using tensorflow::StatSummarizer; using tensorflow::RunMetadata; +using tensorflow::StatSummarizer; namespace { StatSummarizer* requireHandle(JNIEnv* env, jlong handle) { diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 11c3c037c4e8b4ba41eae60d28d6aac49f1488f2..6e0f0a05726a46b513a4270fd5843ff20fc95a18 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -217,6 +217,7 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:random_seed", ], + tags = ["notsan"], ) cuda_py_test( diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 4b5d5ba0de6c3995ee2da7a44ab0ba099cbf1b35..754b7bc3270d647fc381033b769eadd7b791771e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -48,8 +48,9 @@ class CreateTreeEnsembleVariableOp : public OpKernel { if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), stamp_token)) { result->Unref(); - OP_REQUIRES(context, false, errors::InvalidArgument( - "Unable to parse tree ensemble config.")); + OP_REQUIRES( + context, false, + errors::InvalidArgument("Unable to parse tree ensemble config.")); } // Only create one, if one does not exist already. Report status for all diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index f8086b0c2bb93eae6af0336bbe33fc23f8fcde22..b3fe38614e05801b223f0c96f7a70ce7e432a70b 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -47,8 +47,8 @@ namespace boosted_trees { using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearningRateConfig; using boosted_trees::learner::LearningRateDropoutDrivenConfig; -using boosted_trees::models::MultipleAdditiveTrees; using boosted_trees::models::DecisionTreeEnsembleResource; +using boosted_trees::models::MultipleAdditiveTrees; using boosted_trees::utils::DropoutUtils; using boosted_trees::utils::TensorUtils; diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index e91232bf1054048b6b3fd54e678980137b73871a..0f4c2298f56be48bb32f52d5d44cff8afe284f1e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -36,8 +36,8 @@ namespace tensorflow { using ::boosted_trees::QuantileConfig; -using boosted_trees::utils::TensorUtils; using boosted_trees::QuantileStreamResource; +using boosted_trees::utils::TensorUtils; namespace { const char* const kExampleWeightsName = "example_weights"; @@ -384,7 +384,7 @@ class MakeQuantileSummariesOp : public OpKernel { protobuf::Arena arena; ::boosted_trees::QuantileSummaryState* summary_proto = protobuf::Arena::CreateMessage< - ::boosted_trees::QuantileSummaryState>(&arena); + ::boosted_trees::QuantileSummaryState>(&arena); const auto& summary = stream.GetFinalSummary(); CopySummaryToProto(summary, summary_proto); // Output to tensor. diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 18b4abd654ea3541d646a43ac901aca1a678446f..44a8ffaf4b2f5a9c11b3abc46ce55a18c80ad318 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -34,10 +34,10 @@ namespace tensorflow { +using boosted_trees::learner::LearnerConfig_MultiClassStrategy; using boosted_trees::learner::SplitInfo; using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; -using boosted_trees::learner::LearnerConfig_MultiClassStrategy; namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; @@ -47,9 +47,8 @@ class BaseBuildSplitOp : public OpKernel { public: explicit BaseBuildSplitOp(OpKernelConstruction* const context) : OpKernel(context) { - OP_REQUIRES_OK( - context, - context->GetAttr("feature_column_group_id", &feature_column_group_id_)); + OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", + &feature_column_group_id_)); OP_REQUIRES_OK(context, context->GetAttr("l1_regularization", &l1_regularization_)); OP_REQUIRES_OK(context, diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index a9a229c8ae0c26bba5f0a684dad7e546298577bb..90a0655201f8cb8df6fc6417cb51216dec91b4d7 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -134,10 +134,9 @@ void SerializeScalarAccumulatorToOutput( OpKernelContext* context) { int64 num_slots = accumulator_resource.values().size(); Tensor* partition_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_partition_ids", TensorShape({num_slots}), - &partition_ids_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", + TensorShape({num_slots}), + &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); // Feature ids tensor has ids of feature columns and their dimensions. @@ -149,15 +148,14 @@ void SerializeScalarAccumulatorToOutput( Tensor* gradients_t = nullptr; OP_REQUIRES_OK( - context, - context->allocate_output("output_gradients", TensorShape({num_slots}), - &gradients_t)); + context, context->allocate_output( + "output_gradients", TensorShape({num_slots}), &gradients_t)); auto gradients = gradients_t->vec(); Tensor* hessians_t = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output( - "output_hessians", TensorShape({num_slots}), &hessians_t)); + OP_REQUIRES_OK( + context, context->allocate_output("output_hessians", + TensorShape({num_slots}), &hessians_t)); auto hessians = hessians_t->vec(); int i = 0; @@ -177,10 +175,9 @@ void SerializeTensorAccumulatorToOutput( OpKernelContext* context) { int64 num_slots = accumulator_resource.values().size(); Tensor* partition_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_partition_ids", TensorShape({num_slots}), - &partition_ids_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", + TensorShape({num_slots}), + &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); Tensor* feature_ids_t = nullptr; @@ -202,9 +199,8 @@ void SerializeTensorAccumulatorToOutput( int64 num_hessian_elements = hessian_shape.num_elements(); hessian_shape.InsertDim(0, num_slots); Tensor* hessians_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_hessians", hessian_shape, &hessians_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_hessians", + hessian_shape, &hessians_t)); auto hessians = hessians_t->flat_outer_dims(); int i = 0; diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc index f867e77d3ef0609774628b2a9c36ca52bcf2a957..8bca132acfde9397942b198db9a8d4c0e4d74897 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc @@ -17,8 +17,8 @@ #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" -using tensorflow::test::AsTensor; using std::vector; +using tensorflow::test::AsTensor; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 1c4181f1b13b01f85833157e554c3b821f96ff90..8ad97fedc923ac50bcaad86e0ba2c2e46df6821b 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -15,9 +15,9 @@ #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#include #include #include -#include #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h" diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc index cbe26ba918d384ad903fb854ca3e88e84d16a923..705b65e9db9f1aed9af1be153240d57e163c2d5b 100644 --- a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc +++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc @@ -22,9 +22,9 @@ namespace tensorflow { namespace boosted_trees { namespace testutil { +using boosted_trees::trees::DenseFloatBinarySplit; using tensorflow::boosted_trees::trees::DecisionTreeConfig; using tensorflow::boosted_trees::trees::TreeNode; -using boosted_trees::trees::DenseFloatBinarySplit; namespace { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 9de3e32b097a151b3bd6f5c30df2db0938b65e9c..609519e8b1153a27d987c5f9ca9bfcc9ee6717d6 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -25,8 +25,8 @@ namespace boosted_trees { namespace utils { namespace { -using test::AsTensor; using errors::InvalidArgument; +using test::AsTensor; class BatchFeaturesTest : public ::testing::Test {}; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc index 38f0151255bbf4fcd87f1d0d76fd111649ee4a12..db34db998a7442c69f2ab468f4557d991429f4ee 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc @@ -23,10 +23,10 @@ #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/logging.h" +using tensorflow::Status; using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; using tensorflow::random::PhiloxRandom; using tensorflow::random::SimplePhilox; -using tensorflow::Status; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc index ce7632e58987f5890beaded5dd305724f950e1e8..02f972c8e00e8229426ac53d8f20765484787b6e 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc @@ -26,9 +26,9 @@ #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" +using std::unordered_set; using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; using tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig; -using std::unordered_set; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index bb57dcf8ae7475486bcc0fc82460cbbce9a18b68..ae99d53a2cf805d70d60746cd44f73f7fd9dc6e2 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -19,8 +19,8 @@ namespace tensorflow { namespace boosted_trees { -using shape_inference::InferenceContext; using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_RESOURCE_HANDLE_OP(QuantileStreamResource); diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 0d27ddaf3a1d540efee268c2bcca217077ff5871..5d0ebbf73ce1272b51a475f67984db3a181b7130 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -18,9 +18,9 @@ namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -using shape_inference::DimensionHandle; REGISTER_OP("BuildDenseInequalitySplits") .Attr("feature_column_group_id: int") diff --git a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc index 0354f7853cbedf22d0a299273b4dbd225b3121ab..179505eef01f79bb149137400468b84285fe478a 100644 --- a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc @@ -19,9 +19,9 @@ namespace tensorflow { namespace boosted_trees { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -using shape_inference::DimensionHandle; REGISTER_RESOURCE_HANDLE_OP(StatsAccumulatorScalarResource); diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h index 59f23332983e2328286d3b1b8b8c8fa228be991e..fea6b15640ded74432f35112bc5d5d68e641c9dc 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h @@ -399,6 +399,6 @@ const string kTestEmptyRow = R"({ }]}]})"; } // namespace -} // namepsace tensorflow +} // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 15abd2be0385eb776ff4f76484133efb6e34f076..80e18a43a71cc9d6c9e2ccf5836e50c6427a30f6 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -34,6 +34,7 @@ py_library( ":cluster_resolver_py", ":gce_cluster_resolver_py", ":tpu_cluster_resolver_py", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index d17501e87e79158b1602ac6ddecc091bd86f2c2d..b4d8cd4a7cf42e910e7506dbeec8656a2cef62eb 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -26,3 +26,15 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver # pylint: enable=wildcard-import,unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', + 'GceClusterResolver', + 'TPUClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 2e75ac226ea74e879edda5e03dff3d53c8a76569..a6a6e642e4e4c721b94821a70d55d6fe931347d6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -143,7 +143,8 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list.append(instance_url) + if 'health' in response and response['health'] == 'HEALTHY': + instance_url = '%s:%s' % (response['ipAddress'], response['port']) + worker_list.append(instance_url) return ClusterSpec({self._job_name: worker_list}) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 0c4730613af4ad9ca87deb6200ab4bb93d3f6a53..4fd34629cf74f90869c77b8cb098d3c585a49404 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -105,7 +105,8 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -126,7 +127,8 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -147,11 +149,13 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' }, 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { 'ipAddress': '10.4.5.6', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -169,15 +173,54 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testHealthyTpuNodeRetrieval(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'HEALTHY' + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { + 'ipAddress': '10.4.5.6', + 'port': '8470', + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': { + 'ipAddress': '10.7.8.9', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'tpu_worker' + tasks { + key: 0 + value: '10.1.2.3:8470' + } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testGetMasterMultipleEntries(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' }, 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { 'ipAddress': '10.4.5.6', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index aedb793d2aef4bf6950cd074cd065909667eaf75..fd05fa6d47209edd825b6a97aa0b77b3f9cb8ee1 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a) if(WIN32) set(protobuf_STATIC_LIBRARIES diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 9ce8b3cc9cd4783c4b940ea9c7bf0b57fa2a3f28..a7938f1f0752a3e50ebdb18fbd81ed797bb037d7 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -6,6 +6,7 @@ tensorflow/core/example tensorflow/core/framework tensorflow/core/lib tensorflow/core/lib/core +tensorflow/core/profiler tensorflow/core/protobuf tensorflow/core/util tensorflow/examples @@ -216,6 +217,8 @@ tensorflow/contrib/input_pipeline/python/ops tensorflow/contrib/integrate tensorflow/contrib/integrate/python tensorflow/contrib/integrate/python/ops +tensorflow/contrib/kafka/python +tensorflow/contrib/kafka/python/ops tensorflow/contrib/keras tensorflow/contrib/keras/api tensorflow/contrib/keras/api/keras diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index b7c816c24f82c7747f53b4c127866e6008085ef3..34c466fa01e5795fc477403b6fd9704dc0afa1e2 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -307,7 +307,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 2e79eadf7f566690a7742757ceb56e147ebd6ea0..73edd616ea43510d0b34196cbf9c4caba2c8219f 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -310,6 +310,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/control_flow_util_test.py" # Flaky replicate_model_fn_test "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py" # b/71901810 + # Broken io_utils_test + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index f67698eb99a38eae307b52e55de748a67b798cbd..77ea914380dfa4f3ec903e9fc7062d429f8c0d6f 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -31,6 +31,7 @@ from __future__ import division from __future__ import print_function import argparse +import codecs import io import os import re @@ -103,7 +104,7 @@ def main(): for lib_path in args.input: proc = subprocess.Popen([DUMPBIN, "/nologo", "/linkermember:1", lib_path], stdout=subprocess.PIPE) - for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"): + for line in codecs.getreader("utf-8")(proc.stdout): cols = line.split() if len(cols) < 2: continue @@ -131,7 +132,7 @@ def main(): # We compare on undname but use the decorated name from candidates. dupes = 0 proc = subprocess.Popen([UNDNAME, tmpfile.name], stdout=subprocess.PIPE) - for idx, line in enumerate(io.TextIOWrapper(proc.stdout, encoding="utf-8")): + for idx, line in enumerate(codecs.getreader("utf-8")(proc.stdout)): decorated = candidates[idx] if decorated in taken: # Symbol is already in output, done. diff --git a/tensorflow/contrib/coder/kernels/range_coder.cc b/tensorflow/contrib/coder/kernels/range_coder.cc index f4f076b6c4e0c82cc297266bedc63034d5f5bf8b..21b35155ff317c6afbb1b86745f05385726505b6 100644 --- a/tensorflow/contrib/coder/kernels/range_coder.cc +++ b/tensorflow/contrib/coder/kernels/range_coder.cc @@ -276,7 +276,7 @@ void RangeEncoder::Finalize(string* sink) { } } else if (base_ != 0) { // If base == 0, then pick 0 from [base, base + size) and no zeros are - // explcitly written. + // explicitly written. // // Otherwise, pick (base + (2^16 - base[16:0])), i.e., round up base to the // next multiple of 2^16. As 2^16 < size, this value should be in the diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 2108e42bce4eba1eed158fe85888f1699a69ba7e..29a593f6bcfa05dcafcdb2f94087380ad720dba1 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -169,6 +170,7 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) +@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): @@ -188,7 +190,7 @@ class CompilationEnabledInGradientTest(test.TestCase): for cg in c_grad_ops: self.assertTrue(cg.get_attr("_XlaCompile")) for ncg in nc_grad_ops: - with self.assertRaisesRegexp(ValueError, "No attr named"): + with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"): ncg.get_attr("_XlaCompile") # d/dx (x ** 4) = 4 * (x ** 3) diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc index 9e41e67857101534e8bfef8d5d0b8a45ed8f1f76..1a79bf066c3a27e040099729fb079ee963f59270 100644 --- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc @@ -251,9 +251,8 @@ REGISTER_OP("CudnnRNNParamsToCanonical") TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params)); // Set shape for weight matrices for (int i = 0; i < num_params; i++) { - c->set_output(i, - c->Matrix(InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim)); + c->set_output(i, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); } // Set shape for bias vectors for (int i = 0; i < num_params; i++) { @@ -300,6 +299,7 @@ upcoming training or inferences. num_params: number of parameter sets for all layers. Each layer may contain multiple parameter sets, with each set consisting of a weight matrix and a bias vector. -)doc", kCudnnRNNCommonAttrs)); +)doc", + kCudnnRNNCommonAttrs)); } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 015f69c5673f185c53e61a5df2636333699ae203..0c2827b1e49919d236aeb922645236251f1344e0 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -744,6 +744,23 @@ class BatchDatasetSerializationTest( lambda: self._build_dataset_dense_to_sparse(diff_comp), num_outputs) + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 701fc8247e21e4c018c704d82f3a40c7daf1d742..dbc35097ddda9f0375060d43aeb43efa8107f929 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -35,14 +36,29 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest +def remove_variants(get_next_op): + # TODO(b/72408568): Remove this once session.run can get + # variant tensors. + """Remove variants from a nest structure, so sess.run will execute.""" + + def _remove_variant(x): + if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: + return () + else: + return x + + return nest.map_structure(_remove_variant, get_next_op) + + class DatasetSerializationTestBase(test.TestCase): """Base class for testing serializable datasets.""" def tearDown(self): self._delete_ckpt() - # TODO(b/70988345): Support native `tf.SparseTensor` objects and get rid of - # `sparse_tensors` argument. + # TODO(b/72657739): Remove sparse_tensor argument, which is to test the + # (deprecated) saveable `SparseTensorSliceDataset`, once the API + # `from_sparse_tensor_slices()`and related tests are deleted. def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): """Runs the core tests. @@ -234,6 +250,7 @@ class DatasetSerializationTestBase(test.TestCase): saver = self._import_meta_graph() init_op, get_next_op = self._get_iterator_ops_from_collection( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) self._initialize(init_op, sess) @@ -296,6 +313,7 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: _, get_next_op, saver = self._build_graph( ds_fn2, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): @@ -356,6 +374,7 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: get_next_op, saver = self._build_empty_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): @@ -389,6 +408,7 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._initialize(init_op, sess) for _ in range(break_point): @@ -484,11 +504,13 @@ class DatasetSerializationTestBase(test.TestCase): else: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) return init_op, get_next_op, saver for i in range(len(break_points) + 1): with ops.Graph().as_default() as g: init_op, get_next_op, saver = get_ops() + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: if ckpt_saved: if init_before_restore: @@ -559,13 +581,16 @@ class DatasetSerializationTestBase(test.TestCase): get_next = sparse_tensor.SparseTensor(*iterator.get_next()) else: get_next = iterator.get_next() - self._add_iterator_ops_to_collection(init_op, get_next, sparse_tensors) + self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, + sparse_tensors) saver = saver_lib.Saver(allow_empty=True) return init_op, get_next, saver def _build_empty_graph(self, ds_fn, sparse_tensors=False): iterator = iterator_ops.Iterator.from_structure( - self._get_output_types(ds_fn), self._get_output_shapes(ds_fn)) + self._get_output_types(ds_fn), + output_shapes=self._get_output_shapes(ds_fn), + output_classes=self._get_output_classes(ds_fn)) saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) if sparse_tensors: @@ -578,12 +603,19 @@ class DatasetSerializationTestBase(test.TestCase): def _add_iterator_ops_to_collection(self, init_op, get_next, + ds_fn, sparse_tensors=False): ops.add_to_collection("iterator_ops", init_op) # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections # do not support tuples we flatten the tensors and restore the shape in # `_get_iterator_ops_from_collection`. - if sparse_tensors: + + # TODO(shivaniagrwal): `output_classes` is a nested structure of classes, + # this base class is specific to current test cases. Update when tests are + # added with `output_classes` as a nested structure with at least one of the + # component being `tf.SparseTensor`. + if (sparse_tensors or + self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): ops.add_to_collection("iterator_ops", get_next.indices) ops.add_to_collection("iterator_ops", get_next.values) ops.add_to_collection("iterator_ops", get_next.dense_shape) @@ -593,7 +625,8 @@ class DatasetSerializationTestBase(test.TestCase): def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): all_ops = ops.get_collection("iterator_ops") - if sparse_tensors: + if (sparse_tensors or + self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): init_op, indices, values, dense_shape = all_ops return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) else: @@ -608,6 +641,10 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default(): return ds_fn().output_shapes + def _get_output_classes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_classes + def _ckpt_path(self): return os.path.join(self.get_temp_dir(), "iterator") diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 5921be2ae89ba1bbbb8d6e3a509cf49c65949544..06883934d044c2c5faf467dd1708b858a2f8f9ab 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -194,6 +194,10 @@ class FilterDatasetSerializationTest( return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( lambda x, i: x) + def testSparseCore(self): + num_outputs = 5 + self.run_core_tests(self._build_sparse_filter, None, num_outputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index d4fbaa5cdcdd315aa0524134b48eb0515169722c..86d69495ef47da0bc93b8d9b1299e552fc676ee1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -225,6 +225,21 @@ class FlatMapDatasetSerializationTest( self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError) + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_ds(): + return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + + self.run_core_tests(_build_ds, None, 20) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index b1937c08f347734d0d6871bd30ed209ff520623a..db8429512bf2bf944e67b65d185aca99477c86d3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -252,6 +252,22 @@ class InterleaveDatasetSeriazationTest( None, num_outputs) # pylint: enable=g-long-lambda + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + class ParallelInterleaveDatasetTest(test.TestCase): diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index dd8247bfd47a9880c7cfe905103702e43b1f2165..d3ce89298be342e22f12c46e8e8213ef636d0dc6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -805,6 +805,21 @@ class MapDatasetSerializationTest( self.run_core_tests(_build_ds, None, num_outputs) + def testSparseCore(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _build_ds(num_outputs): + return contrib_dataset_ops.Dataset.range(num_outputs).map(_sparse) + + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) + class ParallelMapDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -851,7 +866,8 @@ class ParallelMapDatasetSerializationTest( return random_ops.random_uniform( (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - return contrib_dataset_ops.Dataset.range(100).map(_map_fn) + return contrib_dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -861,7 +877,8 @@ class ParallelMapDatasetSerializationTest( counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1))) + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -870,7 +887,7 @@ class ParallelMapDatasetSerializationTest( def _build_ds(): constant_var = constant_op.constant(5) return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var)) + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) self.run_core_tests(_build_ds, None, 10) @@ -883,7 +900,8 @@ class ParallelMapDatasetSerializationTest( def defun_fn(x): return constant_op.constant(1000) + math_ops.to_int32(x) - return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return contrib_dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) self.run_core_tests(_build_ds, None, num_outputs) @@ -901,7 +919,8 @@ class ParallelMapDatasetSerializationTest( return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return contrib_dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) self.run_core_tests(_build_ds, None, num_outputs) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 76c07b2c999e1424e8efe4af515fddee73922c9c..6eb512dec67cb7b9c8c4518d03aee0b436205f9a 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -403,7 +403,7 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the number of batches to create in parallel. On one hand, higher values can help mitigate the effect of stragglers. On the other hand, higher values - can increasing contention if CPU is scarce. + can increase contention if CPU is scarce. Returns: A `Dataset` transformation function, which can be passed to diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 1dd0729513c0d46db25226178eb17b41efaae0ae..9cd1701c397b5a0bf5cc47c1bcab033704794d80 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -161,8 +162,10 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) @property def output_shapes(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index a255d4fc890e67180532e342332a8e3f63a869cd..31d24aa9ea09007b8db40e4869371b1f62639ac7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -23,10 +23,15 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import mixture +from tensorflow.contrib.distributions.python.ops import mixture_same_family +from tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import categorical +from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -395,6 +400,41 @@ class MixtureStddevTest(test.TestCase): self.assertAllClose(actual_devs, expected_devs) +class PadMixtureDimensionsTest(test.TestCase): + + def test_pad_mixture_dimensions_mixture(self): + with self.test_session() as sess: + gm = mixture.Mixture( + cat=categorical.Categorical(probs=[[0.3, 0.7]]), + components=[ + normal.Normal(loc=[-1.0], scale=[1.0]), + normal.Normal(loc=[1.0], scale=[0.5]) + ]) + + x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]]) + x_pad = distribution_util.pad_mixture_dimensions( + x, gm, gm.cat, gm.event_shape.ndims) + x_out, x_pad_out = sess.run([x, x_pad]) + + self.assertAllEqual(x_pad_out.shape, [2, 2]) + self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1])) + + def test_pad_mixture_dimensions_mixture_same_family(self): + with self.test_session() as sess: + gm = mixture_same_family.MixtureSameFamily( + mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5])) + + x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]]) + x_pad = distribution_util.pad_mixture_dimensions( + x, gm, gm.mixture_distribution, gm.event_shape.ndims) + x_out, x_pad_out = sess.run([x, x_pad]) + + self.assertAllEqual(x_pad_out.shape, [2, 2, 1]) + self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1])) + + class _PadTest(object): def testNegAxisCorrectness(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3c86b5c0f42b64fc6e4e362cbcc162bccf74a2 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -0,0 +1,388 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import kumaraswamy as kumaraswamy_lib +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + + +special = try_import("scipy.special") +stats = try_import("scipy.stats") + + +def _kumaraswamy_mode(a, b): + a = np.asarray(a) + b = np.asarray(b) + return ((a - 1) / (a * b - 1))**(1 / a) + + +def _kumaraswamy_moment(a, b, n): + a = np.asarray(a) + b = np.asarray(b) + return b * special.beta(1.0 + n / a, b) + + +def _harmonic_number(b): + b = np.asarray(b) + return special.psi(b + 1) - special.psi(1) + + +def _kumaraswamy_cdf(a, b, x): + a = np.asarray(a) + b = np.asarray(b) + x = np.asarray(x) + return 1 - (1 - x**a)**b + + +def _kumaraswamy_pdf(a, b, x): + a = np.asarray(a) + b = np.asarray(b) + x = np.asarray(x) + return a * b * x ** (a - 1) * (1 - x ** a) ** (b - 1) + + +class KumaraswamyTest(test.TestCase): + + def testSimpleShapes(self): + with self.test_session(): + a = np.random.rand(3) + b = np.random.rand(3) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) + + def testComplexShapes(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(3, 2, 2) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + + def testComplexShapesBroadcast(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(2, 2) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + + def testAProperty(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual([1, 3], dist.concentration1.get_shape()) + self.assertAllClose(a, dist.concentration1.eval()) + + def testBProperty(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual([1, 3], dist.concentration0.get_shape()) + self.assertAllClose(b, dist.concentration0.eval()) + + def testPdfXProper(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b, validate_args=True) + dist.prob([.1, .3, .6]).eval() + dist.prob([.2, .3, .5]).eval() + # Either condition can trigger. + with self.assertRaisesOpError("sample must be positive"): + dist.prob([-1., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be positive"): + dist.prob([0., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be no larger than `1`"): + dist.prob([.1, .2, 1.2]).eval() + + def testPdfTwoBatches(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.5, .5] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfTwoBatchesNontrivialX(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.3, .7] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfUniformZeroBatch(self): + with self.test_session(): + # This is equivalent to a uniform distribution + a = 1. + b = 1. + x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((5,), pdf.get_shape()) + + def testPdfAStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2]] + b = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfAStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testKumaraswamyMean(self): + with session.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.mean().get_shape(), (3,)) + if not stats: + return + expected_mean = _kumaraswamy_moment(a, b, 1) + self.assertAllClose(expected_mean, dist.mean().eval()) + + def testKumaraswamyVariance(self): + with session.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.variance().get_shape(), (3,)) + if not stats: + return + expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment( + a, b, 1)**2 + self.assertAllClose(expected_variance, dist.variance().eval()) + + def testKumaraswamyMode(self): + with session.Session(): + a = np.array([1.1, 2, 3]) + b = np.array([2., 4, 1.2]) + expected_mode = _kumaraswamy_mode(a, b) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.mode().get_shape(), (3,)) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testKumaraswamyModeInvalid(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + dist.mode().eval() + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + dist.mode().eval() + + def testKumaraswamyModeEnableAllowNanStats(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True) + + expected_mode = _kumaraswamy_mode(a, b) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True) + + expected_mode = _kumaraswamy_mode(a, b) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testKumaraswamyEntropy(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = (1 - 1. / a) + ( + 1 - 1. / b) * _harmonic_number(b) + np.log(a * b) + self.assertAllClose(expected_entropy, dist.entropy().eval()) + + def testKumaraswamySample(self): + with self.test_session(): + a = 1. + b = 2. + kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b) + n = constant_op.constant(100000) + samples = kumaraswamy.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000,)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertLess( + stats.kstest( + # Kumaraswamy is a univariate distribution. + sample_values, + lambda x: _kumaraswamy_cdf(1., 2., x))[0], + 0.01) + # The standard error of the sample mean is 1 / (sqrt(18 * n)) + expected_mean = _kumaraswamy_moment(a, b, 1) + self.assertAllClose(sample_values.mean(axis=0), expected_mean, atol=1e-2) + expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment( + a, b, 1)**2 + self.assertAllClose( + np.cov(sample_values, rowvar=0), expected_variance, atol=1e-1) + + # Test that sampling with the same seed twice gives the same results. + def testKumaraswamySampleMultipleTimes(self): + with self.test_session(): + a_val = 1. + b_val = 2. + n_val = 100 + + random_seed.set_random_seed(654321) + kumaraswamy1 = kumaraswamy_lib.Kumaraswamy( + concentration1=a_val, concentration0=b_val, name="kumaraswamy1") + samples1 = kumaraswamy1.sample(n_val, seed=123456).eval() + + random_seed.set_random_seed(654321) + kumaraswamy2 = kumaraswamy_lib.Kumaraswamy( + concentration1=a_val, concentration0=b_val, name="kumaraswamy2") + samples2 = kumaraswamy2.sample(n_val, seed=123456).eval() + + self.assertAllClose(samples1, samples2) + + def testKumaraswamySampleMultidimensional(self): + with self.test_session(): + a = np.random.rand(3, 2, 2).astype(np.float32) + b = np.random.rand(3, 2, 2).astype(np.float32) + kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b) + n = constant_op.constant(100000) + samples = kumaraswamy.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertAllClose( + sample_values[:, 1, :].mean(axis=0), + _kumaraswamy_moment(a, b, 1)[1, :], + atol=1e-1) + + def testKumaraswamyCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = kumaraswamy_lib.Kumaraswamy(a, b).cdf(x).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose( + _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0) + + def testKumaraswamyLogCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = math_ops.exp(kumaraswamy_lib.Kumaraswamy(a, + b).log_cdf(x)).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose( + _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index 1e514fe0ff21cd53c8c235da417890773db50c37..02064891758a86c5108e11da6a3666f2d5c56c64 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py @@ -107,7 +107,7 @@ def _test_capture_normal_sample_outputs(): ds.Normal._call_sample_n = true_normal_call_sample_n -def make_univariate_mixture(batch_shape, num_components): +def make_univariate_mixture(batch_shape, num_components, use_static_graph): batch_shape = ops.convert_to_tensor(batch_shape, dtypes.int32) logits = random_ops.random_uniform( array_ops.concat((batch_shape, [num_components]), axis=0), @@ -119,11 +119,11 @@ def make_univariate_mixture(batch_shape, num_components): for _ in range(num_components) ] cat = ds.Categorical(logits, dtype=dtypes.int32) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=use_static_graph) def make_multivariate_mixture(batch_shape, num_components, event_shape, - batch_shape_tensor=None): + use_static_graph, batch_shape_tensor=None): if batch_shape_tensor is None: batch_shape_tensor = batch_shape batch_shape_tensor = ops.convert_to_tensor(batch_shape_tensor, dtypes.int32) @@ -145,15 +145,17 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape, loc=loc, scale_diag=scale_diag) components = [create_component() for _ in range(num_components)] cat = ds.Categorical(logits, dtype=dtypes.int32) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=use_static_graph) class MixtureTest(test.TestCase): + use_static_graph = False def testShapes(self): with self.test_session(): for batch_shape in ([], [1], [2, 3, 4]): - dist = make_univariate_mixture(batch_shape, num_components=10) + dist = make_univariate_mixture(batch_shape, num_components=10, + use_static_graph=self.use_static_graph) self.assertAllEqual(batch_shape, dist.batch_shape) self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval()) self.assertAllEqual([], dist.event_shape) @@ -161,7 +163,8 @@ class MixtureTest(test.TestCase): for event_shape in ([1], [2]): dist = make_multivariate_mixture( - batch_shape, num_components=10, event_shape=event_shape) + batch_shape, num_components=10, event_shape=event_shape, + use_static_graph=self.use_static_graph) self.assertAllEqual(batch_shape, dist.batch_shape) self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval()) self.assertAllEqual(event_shape, dist.event_shape) @@ -172,7 +175,8 @@ class MixtureTest(test.TestCase): r"cat.num_classes != len"): ds.Mixture( ds.Categorical([0.1, 0.5]), # 2 classes - [ds.Normal(loc=1.0, scale=2.0)]) + [ds.Normal(loc=1.0, scale=2.0)], + use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch( ValueError, r"\(\) and \(2,\) are not compatible"): # The value error is raised because the batch shapes of the @@ -185,13 +189,15 @@ class MixtureTest(test.TestCase): loc=1.0, scale=2.0), # scalar dist ds.Normal( loc=[1.0, 1.0], scale=[2.0, 2.0]) - ]) + ], + use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"): cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32) ds.Mixture( ds.Categorical(cat_logits), [ds.Normal( - loc=[1.0], scale=[2.0])]) + loc=[1.0], scale=[2.0])], + use_static_graph=self.use_static_graph) def testBrokenShapesDynamic(self): with self.test_session(): @@ -203,29 +209,37 @@ class MixtureTest(test.TestCase): loc=d0_param, scale=d0_param), ds.Normal( loc=d1_param, scale=d1_param) ], - validate_args=True) - with self.assertRaisesOpError(r"batch shape must match"): + validate_args=True, + use_static_graph=self.use_static_graph) + + if self.use_static_graph: + error_string = r"Shapes of all inputs must match" + else: + error_string = r"batch shape must match" + + with self.assertRaisesOpError(error_string): d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]}) - with self.assertRaisesOpError(r"batch shape must match"): + with self.assertRaisesOpError(error_string): d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: 1.0}) def testBrokenTypes(self): with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"): - ds.Mixture(None, []) + ds.Mixture(None, [], use_static_graph=self.use_static_graph) cat = ds.Categorical([0.3, 0.2]) # components must be a list of distributions with self.assertRaisesWithPredicateMatch( TypeError, "all .* must be Distribution instances"): - ds.Mixture(cat, [None]) + ds.Mixture(cat, [None], use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"): ds.Mixture( cat, [ ds.Normal(loc=[1.0], scale=[2.0]), ds.Normal(loc=[np.float16(1.0)], scale=[np.float16(2.0)]), - ]) + ], use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"): - ds.Mixture(ds.Categorical([0.3, 0.2]), None) + ds.Mixture(ds.Categorical([0.3, 0.2]), None, + use_static_graph=self.use_static_graph) # TODO(ebrevdo): once distribution Domains have been added, add a # test to ensure that the domains of the distributions in a @@ -235,7 +249,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=2) + batch_shape=batch_shape, num_components=2, + use_static_graph=self.use_static_graph) mean = dist.mean() self.assertEqual(batch_shape, mean.get_shape()) @@ -256,7 +271,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_multivariate_mixture( - batch_shape=batch_shape, num_components=2, event_shape=(4,)) + batch_shape=batch_shape, num_components=2, event_shape=(4,), + use_static_graph=self.use_static_graph) mean = dist.mean() self.assertEqual(batch_shape + (4,), mean.get_shape()) @@ -283,7 +299,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=num_components) + batch_shape=batch_shape, num_components=num_components, + use_static_graph=self.use_static_graph) dev = dist.stddev() self.assertEqual(batch_shape, dev.get_shape()) @@ -325,7 +342,8 @@ class MixtureTest(test.TestCase): dist = make_multivariate_mixture( batch_shape=batch_shape, num_components=num_components, - event_shape=(4,)) + event_shape=(4,), + use_static_graph=self.use_static_graph) dev = dist.stddev() self.assertEqual(batch_shape + (4,), dev.get_shape()) @@ -371,7 +389,8 @@ class MixtureTest(test.TestCase): scale=component_devs[0]), ds.Normal(loc=component_means[1], scale=component_devs[1]), - ]) + ], + use_static_graph=self.use_static_graph) mix_dev = mixture_dist.stddev() with self.test_session() as sess: actual_stddev = sess.run(mix_dev) @@ -379,7 +398,8 @@ class MixtureTest(test.TestCase): def testProbScalarUnivariate(self): with self.test_session() as sess: - dist = make_univariate_mixture(batch_shape=[], num_components=2) + dist = make_univariate_mixture(batch_shape=[], num_components=2, + use_static_graph=self.use_static_graph) for x in [ np.array( [1.0, 2.0], dtype=np.float32), np.array( @@ -405,7 +425,8 @@ class MixtureTest(test.TestCase): def testProbScalarMultivariate(self): with self.test_session() as sess: dist = make_multivariate_mixture( - batch_shape=[], num_components=2, event_shape=[3]) + batch_shape=[], num_components=2, event_shape=[3], + use_static_graph=self.use_static_graph) for x in [ np.array( [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array( @@ -432,7 +453,8 @@ class MixtureTest(test.TestCase): def testProbBatchUnivariate(self): with self.test_session() as sess: - dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2) + dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2, + use_static_graph=self.use_static_graph) for x in [ np.random.randn(2, 3).astype(np.float32), @@ -459,7 +481,8 @@ class MixtureTest(test.TestCase): def testProbBatchMultivariate(self): with self.test_session() as sess: dist = make_multivariate_mixture( - batch_shape=[2, 3], num_components=2, event_shape=[4]) + batch_shape=[2, 3], num_components=2, event_shape=[4], + use_static_graph=self.use_static_graph) for x in [ np.random.randn(2, 3, 4).astype(np.float32), @@ -487,7 +510,8 @@ class MixtureTest(test.TestCase): num_components = 3 batch_shape = [] dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=num_components) + batch_shape=batch_shape, num_components=num_components, + use_static_graph=self.use_static_graph) n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -502,7 +526,10 @@ class MixtureTest(test.TestCase): which_c = np.where(cat_sample_values == c)[0] size_c = which_c.size # Scalar Batch univariate case: batch_size == 1, rank 1 - which_dist_samples = dist_sample_values[c][:size_c] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c] + else: + which_dist_samples = dist_sample_values[c][:size_c] self.assertAllClose(which_dist_samples, sample_values[which_c]) # Test that sampling with the same seed twice gives the same results. @@ -522,7 +549,8 @@ class MixtureTest(test.TestCase): ] cat = ds.Categorical( logits, dtype=dtypes.int32, name="cat1") - dist1 = ds.Mixture(cat, components, name="mixture1") + dist1 = ds.Mixture(cat, components, name="mixture1", + use_static_graph=self.use_static_graph) samples1 = dist1.sample(n, seed=123456).eval() random_seed.set_random_seed(654321) @@ -532,7 +560,8 @@ class MixtureTest(test.TestCase): ] cat2 = ds.Categorical( logits, dtype=dtypes.int32, name="cat2") - dist2 = ds.Mixture(cat2, components2, name="mixture2") + dist2 = ds.Mixture(cat2, components2, name="mixture2", + use_static_graph=self.use_static_graph) samples2 = dist2.sample(n, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -541,7 +570,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: num_components = 3 dist = make_multivariate_mixture( - batch_shape=[], num_components=num_components, event_shape=[2]) + batch_shape=[], num_components=num_components, event_shape=[2], + use_static_graph=self.use_static_graph) n = 4 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -555,14 +585,18 @@ class MixtureTest(test.TestCase): which_c = np.where(cat_sample_values == c)[0] size_c = which_c.size # Scalar Batch multivariate case: batch_size == 1, rank 2 - which_dist_samples = dist_sample_values[c][:size_c, :] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c, :] + else: + which_dist_samples = dist_sample_values[c][:size_c, :] self.assertAllClose(which_dist_samples, sample_values[which_c, :]) def testSampleBatchUnivariate(self): with self.test_session() as sess: num_components = 3 dist = make_univariate_mixture( - batch_shape=[2, 3], num_components=num_components) + batch_shape=[2, 3], num_components=num_components, + use_static_graph=self.use_static_graph) n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -576,8 +610,12 @@ class MixtureTest(test.TestCase): which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c) size_c = which_c_s.size # Batch univariate case: batch_size == [2, 3], rank 3 - which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, - which_c_b1] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0, + which_c_b1] + else: + which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, + which_c_b1] self.assertAllClose(which_dist_samples, sample_values[which_c_s, which_c_b0, which_c_b1]) @@ -594,7 +632,8 @@ class MixtureTest(test.TestCase): dist = make_multivariate_mixture( batch_shape=batch_shape, num_components=num_components, event_shape=[4], - batch_shape_tensor=batch_shape_tensor) + batch_shape_tensor=batch_shape_tensor, + use_static_graph=self.use_static_graph) n = 5 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -617,8 +656,12 @@ class MixtureTest(test.TestCase): which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c) size_c = which_c_s.size # Batch univariate case: batch_size == [2, 3], rank 4 (multivariate) - which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, - which_c_b1, :] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0, + which_c_b1, :] + else: + which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, + which_c_b1, :] self.assertAllClose(which_dist_samples, sample_values[which_c_s, which_c_b0, which_c_b1, :]) @@ -632,7 +675,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_multivariate_mixture( - batch_shape=batch_shape, num_components=2, event_shape=(4,)) + batch_shape=batch_shape, num_components=2, event_shape=(4,), + use_static_graph=self.use_static_graph) entropy_lower_bound = dist.entropy_lower_bound() self.assertEqual(batch_shape, entropy_lower_bound.get_shape()) @@ -673,7 +717,8 @@ class MixtureTest(test.TestCase): cat_tf = ds.Categorical(probs=mixture_weights) components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf, + use_static_graph=self.use_static_graph) x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32) @@ -721,7 +766,8 @@ class MixtureTest(test.TestCase): cat_tf = ds.Categorical(probs=mixture_weights) components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf, + use_static_graph=self.use_static_graph) x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32) xs_to_check = [ @@ -760,12 +806,18 @@ class MixtureTest(test.TestCase): gm = ds.Mixture( cat=ds.Categorical(probs=[.3, .7]), components=[ds.Gamma(1., 2.), - ds.Gamma(2., 1.)]) + ds.Gamma(2., 1.)], + use_static_graph=self.use_static_graph) x_ = gm.sample().eval() self.assertAllEqual([], x_.shape) +class MixtureStaticSampleTest(MixtureTest): + use_static_graph = True + + class MixtureBenchmark(test.Benchmark): + use_static_graph = False def _runSamplingBenchmark(self, name, create_distribution, use_gpu, num_components, batch_size, num_features, @@ -811,7 +863,7 @@ class MixtureBenchmark(test.Benchmark): components = list( ds.MultivariateNormalDiag( loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas)) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=self.use_static_graph) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -853,7 +905,7 @@ class MixtureBenchmark(test.Benchmark): ds.MultivariateNormalTriL( loc=mu, scale_tril=linalg_ops.cholesky(sigma)) for (mu, sigma) in zip(mus, sigmas)) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=self.use_static_graph) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -872,5 +924,9 @@ class MixtureBenchmark(test.Benchmark): sample_size=sample_size) +class MixtureStaticSampleBenchmark(MixtureBenchmark): + use_static_graph = True + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index dc8ae1eed19eda772219287d8661f534ac242d10..5251dbcb5748f75688aa43ce6e4e9dbd76be78bb 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -237,6 +237,11 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): return y event_size = array_ops.shape(x)[-1] + # If the event size is available at graph construction time, we can inform + # the graph compiler of the maximum number of steps. If not, + # static_event_size will be None, and the maximum_iterations argument will + # have no effect. + static_event_size = x.shape.with_rank_at_least(1)[-1].value y0 = array_ops.zeros_like(x, name="y0") # call the template once to ensure creation _ = self._shift_and_log_scale_fn(y0) @@ -258,7 +263,8 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): _, y = control_flow_ops.while_loop( cond=lambda index, _: index < event_size, body=_loop_body, - loop_vars=[0, y0]) + loop_vars=(0, y0), + maximum_iterations=static_event_size) return y def _inverse(self, y): diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index a4d249d41ec9733721a3583d3708e0da56db1733..289e1d50e1146a641c0cc433ece3465aed73b1c2 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import linalg +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -442,6 +443,44 @@ def maybe_check_scalar_distribution( return assertions +def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, + event_ndims): + """Pad dimensions of event tensors for mixture distributions. + + See `Mixture._sample_n` and `MixtureSameFamily._sample_n` for usage examples. + + Args: + x: event tensor to pad. + mixture_distribution: Base distribution of the mixture. + categorical_distribution: `Categorical` distribution that mixes the base + distribution. + event_ndims: Integer specifying the number of event dimensions in the event + tensor. + + Returns: + A padded version of `x` that can broadcast with `categorical_distribution`. + """ + with ops.name_scope("pad_mix_dims", values=[x]): + def _get_ndims(d): + if d.batch_shape.ndims is not None: + return d.batch_shape.ndims + return array_ops.shape(d.batch_shape_tensor())[0] + dist_batch_ndims = _get_ndims(mixture_distribution) + cat_batch_ndims = _get_ndims(categorical_distribution) + pad_ndims = array_ops.where( + categorical_distribution.is_scalar_batch(), + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) + s = array_ops.shape(x) + x = array_ops.reshape(x, shape=array_ops.concat([ + s[:-1], + array_ops.ones([pad_ndims], dtype=dtypes.int32), + s[-1:], + array_ops.ones([event_ndims], dtype=dtypes.int32), + ], axis=0)) + return x + + def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..74d5d8773cf3e69a52554c87d656fea2835c8354 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -0,0 +1,258 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Kumaraswamy distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops.distributions import beta +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export + +__all__ = [ + "Kumaraswamy", +] + +_kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in +`[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" + + +def _harmonic_number(x): + """Compute the harmonic number from its analytic continuation. + + Derivation from [1] and Euler's constant [2]. + [1] - + https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers + [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant + + + Args: + x: input float. + + Returns: + z: The analytic continuation of the harmonic number for the input. + + """ + one = array_ops.ones([], dtype=x.dtype) + return math_ops.digamma(x + one) - math_ops.digamma(one) + + +@tf_export("distributions.Kumaraswamy") +class Kumaraswamy(beta.Beta): + """Kumaraswamy distribution. + + The Kumaraswamy distribution is defined over the `(0, 1)` interval using + parameters + `concentration1` (aka "alpha") and `concentration0` (aka "beta"). It has a + shape similar to the Beta distribution, but is reparameterizeable. + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; alpha, beta) = alpha * beta * x**(alpha - 1) * (1 - x**alpha)**(beta - + 1) + ``` + + where: + + * `concentration1 = alpha`, + * `concentration0 = beta`, + + Distribution parameters are automatically broadcast in all functions; see + examples for details. + + #### Examples + + ```python + # Create a batch of three Kumaraswamy distributions. + alpha = [1, 2, 3] + beta = [1, 2, 3] + dist = Kumaraswamy(alpha, beta) + + dist.sample([4, 5]) # Shape [4, 5, 3] + + # `x` has three batch entries, each with two samples. + x = [[.1, .4, .5], + [.2, .3, .5]] + # Calculate the probability of each pair of samples under the corresponding + # distribution in `dist`. + dist.prob(x) # Shape [2, 3] + ``` + + ```python + # Create batch_shape=[2, 3] via parameter broadcast: + alpha = [[1.], [2]] # Shape [2, 1] + beta = [3., 4, 5] # Shape [3] + dist = Kumaraswamy(alpha, beta) + + # alpha broadcast as: [[1., 1, 1,], + # [2, 2, 2]] + # beta broadcast as: [[3., 4, 5], + # [3, 4, 5]] + # batch_Shape [2, 3] + dist.sample([4, 5]) # Shape [4, 5, 2, 3] + + x = [.2, .3, .5] + # x will be broadcast as [[.2, .3, .5], + # [.2, .3, .5]], + # thus matching batch_shape [2, 3]. + dist.prob(x) # Shape [2, 3] + ``` + + """ + + def __init__(self, + concentration1=None, + concentration0=None, + validate_args=False, + allow_nan_stats=True, + name="Kumaraswamy"): + """Initialize a batch of Kumaraswamy distributions. + + Args: + concentration1: Positive floating-point `Tensor` indicating mean + number of successes; aka "alpha". Implies `self.dtype` and + `self.batch_shape`, i.e., + `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`. + concentration0: Positive floating-point `Tensor` indicating mean + number of failures; aka "beta". Otherwise has same semantics as + `concentration1`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + super(Kumaraswamy, self).__init__( + concentration1=concentration1, + concentration0=concentration0, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=name) + self._reparameterization_type = distribution.FULLY_REPARAMETERIZED + + def _sample_n(self, n, seed=None): + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + uniform_sample = random_ops.random_uniform( + shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) + + kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( + 1. / expanded_concentration1) + return kumaraswamy_sample + + @distribution_util.AppendDocstring(_kumaraswamy_sample_note) + def _log_cdf(self, x): + a = self.concentration1 + b = self.concentration0 + return math_ops.log1p(-(1 - x**a)**b) + + @distribution_util.AppendDocstring(_kumaraswamy_sample_note) + def _cdf(self, x): + a = self.concentration1 + b = self.concentration0 + return 1 - (1 - x**a)**b + + def _survival_function(self, x): + a = self.concentration1 + b = self.concentration0 + return (1 - x**a)**b + + def _log_survival_function(self, x): + a = self.concentration1 + b = self.concentration0 + return b * math_ops.log1p(-x**a) + + def _log_unnormalized_prob(self, x): + x = self._maybe_assert_valid_sample(x) + a = self.concentration1 + b = self.concentration0 + return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) + + def _log_normalization(self): + a = self.concentration1 + b = self.concentration0 + return -(math_ops.log(a) + math_ops.log(b)) + + def _entropy(self): + a = self.concentration1 + b = self.concentration0 + return (1 - 1. / a) + ( + 1 - 1. / b) * _harmonic_number(b) + math_ops.log(a) + math_ops.log(b) + + def _moment(self, n): + """Compute the n'th (uncentered) moment.""" + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 + beta_arg0 = 1 + n / expanded_concentration1 + beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) + log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( + beta_arg) + return math_ops.exp(log_moment) + + def _mean(self): + return self._moment(1) + + def _variance(self): + # TODO(b/72696533): Investigate a more numerically stable version. + return self._moment(2) - math_ops.square(self._moment(1)) + + @distribution_util.AppendDocstring( + """Note: The mode is undefined when `concentration1 <= 1` or + `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN` + is used for undefined modes. If `self.allow_nan_stats` is `False` an + exception is raised when one or more modes are undefined.""") + def _mode(self): + a = self.concentration1 + b = self.concentration0 + mode = ((a - 1) / (a * b - 1))**(1. / a) + if self.allow_nan_stats: + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype), + name="nan") + is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) + return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration1, + message="Mode undefined for concentration1 <= 1."), + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration0, + message="Mode undefined for concentration0 <= 1.") + ], mode) diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index f2d492f5489a197157558ae727416b51db04793e..cef6a143fc615901315a3780bf4ed53b8c7cd177 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -71,6 +71,7 @@ class Mixture(distribution.Distribution): components, validate_args=False, allow_nan_stats=True, + use_static_graph=False, name="Mixture"): """Initialize a Mixture distribution. @@ -96,6 +97,11 @@ class Mixture(distribution.Distribution): exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. + use_static_graph: Calls to `sample` will not rely on dynamic tensor + indexing, allowing for some static graph compilation optimizations, but + at the expense of sampling all underlying distributions in the mixture. + (Possibly useful when running on TPUs). + Default value: `False` (i.e., use dynamic indexing). name: A name for this distribution (optional). Raises: @@ -178,6 +184,10 @@ class Mixture(distribution.Distribution): self._static_event_shape = static_event_shape self._static_batch_shape = static_batch_shape + self._use_static_graph = use_static_graph + if use_static_graph and static_num_components is None: + raise ValueError("Number of categories must be known statically when " + "`static_sample=True`.") # We let the Mixture distribution access _graph_parents since its arguably # more like a baseclass. graph_parents = self._cat._graph_parents # pylint: disable=protected-access @@ -292,6 +302,31 @@ class Mixture(distribution.Distribution): return mixture_log_cdf def _sample_n(self, n, seed=None): + if self._use_static_graph: + # This sampling approach is almost the same as the approach used by + # `MixtureSameFamily`. The differences are due to having a list of + # `Distribution` objects rather than a single object, and maintaining + # random seed management that is consistent with the non-static code path. + samples = [] + cat_samples = self.cat.sample(n, seed=seed) + for c in range(self.num_components): + seed = distribution_util.gen_new_seed(seed, "mixture") + samples.append(self.components[c].sample(n, seed=seed)) + x = array_ops.stack( + samples, -self._static_event_shape.ndims - 1) # [n, B, k, E] + npdt = x.dtype.as_numpy_dtype + mask = array_ops.one_hot( + indices=cat_samples, # [n, B] + depth=self._num_components, # == k + on_value=np.ones([], dtype=npdt), + off_value=np.zeros([], dtype=npdt)) # [n, B, k] + mask = distribution_utils.pad_mixture_dimensions( + mask, self, self._cat, + self._static_event_shape.ndims) # [n, B, k, [1]*e] + return math_ops.reduce_sum( + x * mask, + axis=-1 - self._static_event_shape.ndims) # [n, B, E] + with ops.control_dependencies(self._assertions): n = ops.convert_to_tensor(n, name="n") static_n = tensor_util.constant_value(n) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 49afbea7f05136674aa0c1441bd46548b7b55c8f..b93bdc5ab4010663baddda1410b302644853648b 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.framework import dtypes +from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -239,7 +239,9 @@ class MixtureSameFamily(distribution.Distribution): depth=self._num_components, # == k on_value=np.ones([], dtype=npdt), off_value=np.zeros([], dtype=npdt)) # [n, B, k] - mask = self._pad_mix_dims(mask) # [n, B, k, [1]*e] + mask = distribution_utils.pad_mixture_dimensions( + mask, self, self.mixture_distribution, + self._event_shape().ndims) # [n, B, k, [1]*e] return math_ops.reduce_sum( x * mask, axis=-1 - self._event_ndims) # [n, B, E] @@ -254,8 +256,9 @@ class MixtureSameFamily(distribution.Distribution): def _mean(self): with ops.control_dependencies(self._runtime_assertions): - probs = self._pad_mix_dims( - self.mixture_distribution.probs) # [B, k, [1]*e] + probs = distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, [1]*e] return math_ops.reduce_sum( probs * self.components_distribution.mean(), axis=-1 - self._event_ndims) # [B, E] @@ -271,8 +274,9 @@ class MixtureSameFamily(distribution.Distribution): def _variance(self): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) - probs = self._pad_mix_dims( - self.mixture_distribution.probs) # [B, k, [1]*e] + probs = distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, [1]*e] mean_cond_var = math_ops.reduce_sum( probs * self.components_distribution.variance(), axis=-1 - self._event_ndims) # [B, E] @@ -291,8 +295,12 @@ class MixtureSameFamily(distribution.Distribution): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) - probs = self._pad_mix_dims(self._pad_mix_dims( - self.mixture_distribution.probs)) # [B, k, 1, 1] + probs = distribution_utils.pad_mixture_dimensions( + distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims), + self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, 1, 1] mean_cond_var = math_ops.reduce_sum( probs * self.components_distribution.covariance(), axis=-3) # [B, e, e] @@ -312,27 +320,6 @@ class MixtureSameFamily(distribution.Distribution): shape[:d], [1], shape[d:]], axis=0)) return x - def _pad_mix_dims(self, x): - with ops.name_scope("pad_mix_dims", values=[x]): - def _get_ndims(d): - if d.batch_shape.ndims is not None: - return d.batch_shape.ndims - return array_ops.shape(d.batch_shape_tensor())[0] - dist_batch_ndims = _get_ndims(self) - cat_batch_ndims = _get_ndims(self.mixture_distribution) - pad_ndims = array_ops.where( - self.mixture_distribution.is_scalar_batch(), - dist_batch_ndims, - dist_batch_ndims - cat_batch_ndims) - s = array_ops.shape(x) - x = array_ops.reshape(x, shape=array_ops.concat([ - s[:-1], - array_ops.ones([pad_ndims], dtype=dtypes.int32), - s[-1:], - array_ops.ones([self._event_ndims], dtype=dtypes.int32), - ], axis=0)) - return x - def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 09242ee47ddd044dfc99e22d5b7751a989c86485..9d2ca07c3a25fa7acb9b0f5806b763d9a57b51fa 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -41,28 +41,8 @@ support for distributed and multi-GPU training and CPU performance. ## Installation -Since eager execution is not yet part of a TensorFlow release, using it requires -either [building from source](https://www.tensorflow.org/install/install_sources) -or the latest nightly builds. The nightly builds are available as: - -- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and - -- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. - -For example, to run the latest nightly docker image: - -```sh -# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker -nvidia-docker pull tensorflow/tensorflow:nightly-gpu -nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu - -# If you do not have a GPU, use the CPU-only image -docker pull tensorflow/tensorflow:nightly -docker run -it -p 8888:8888 tensorflow/tensorflow:nightly -``` - -And then visit http://localhost:8888 in your browser for a Jupyter notebook -environment. Try out the notebooks below. +Eager execution is included in TensorFlow versions 1.5 and above. +Installation instructions at https://www.tensorflow.org/install/ ## Documentation diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index 2a7be95811f6fff06e2c489890703561ed879c42..772f59562ba27cce510c82681f491d005298f44c 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -39,7 +39,7 @@ class MNISTModel(tfe.Network): """MNIST Network. Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py + https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py and https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py @@ -95,8 +95,7 @@ class MNISTModel(tfe.Network): x = self.max_pool2d(x) x = tf.layers.flatten(x) x = self.fc1(x) - if training: - x = self.dropout(x) + x = self.dropout(x, training=training) x = self.fc2(x) return x diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 7eea93ce1f5aefe82d73b49f57b636692818ba16..ffc1d0332eae605ce0444a225e53baa68954cae0 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -19,29 +19,34 @@ to models defined without using eager execution. ## Installation -Eager execution is **not** included in the latest release (version 1.4) of -TensorFlow. To use it, you will need to [build TensorFlow from -source](https://www.tensorflow.org/install/install_sources) or install the -nightly builds. +Eager execution is included in TensorFlow versions 1.5 and above. +Installation instructions at https://www.tensorflow.org/install/ -For example, the nightly builds can be installed using `pip`: +The contents of this guide are compatible with TensorFlow 1.5. +However, if you run into bugs that are fixed in source but not the +release, you may want to either either [building from +source](https://www.tensorflow.org/install/install_sources) +or the try latest nightly builds. The nightly builds are available as: -- `pip install tf-nightly` (for CPU-only TensorFlow) -- `pip install tf-nightly-gpu` (for GPU-enabled TensorFlow) +- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and -Or using `docker`, with [Jupyter Notebook](http://jupyter.org/) support: +- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. + +For example, to run the latest nightly docker image: ```sh -# For CPU-only TensorFlow +# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker +docker pull tensorflow/tensorflow:nightly-gpu +docker run --runtime=nvidia -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu + +# If you do not have a GPU, use the CPU-only image docker pull tensorflow/tensorflow:nightly docker run -it -p 8888:8888 tensorflow/tensorflow:nightly - -# For GPU-enabled TensorFlow: -# (Requires https://github.com/NVIDIA/nvidia-docker) -nvidia-docker pull tensorflow/tensorflow:nightly-gpu -nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu ``` +And then visit http://localhost:8888 in your browser for a Jupyter notebook +environment. + ## Getting Started With TensorFlow installed, eager execution is enabled via a single call: diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index abc7e3690c76c4446bce6b945325f1ca15ef1c8b..1a7f7b85e688e80e3cf482f2754462888187d311 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -73,16 +73,6 @@ class SaverTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) - def testDifferentGraphError(self): - with ops.device(self._dev()): - with ops.Graph().as_default(): - v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') - with self.assertRaisesRegexp(ValueError, 'Graph'): - saver.save(ckpt_prefix) - def testSameObjectOK(self): with ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 712d1cb94d2f565bf6216f6c07a45d3d855efe9c..d32bebf90c1e768d1efec26b3b78bf1a522a8f00 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -59,7 +59,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@in_eager_mode @@in_graph_mode -@@IsolateTest @@run_test_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@ -101,7 +100,6 @@ from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run -from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 0dedb2fd7c0905801cd87c239ff2ee09eecb6080..b6659c2a1797feab261d756e78b45231dbea5a02 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -102,10 +102,6 @@ class TFETest(test_util.TensorFlowTestCase): # Expect at least one device. self.assertTrue(tfe.list_devices()) - def testNumGPUs(self): - devices = tfe.list_devices() - self.assertEqual(len(devices) - 1, tfe.num_gpus()) - def testAddCheckNumericsOpsRaisesError(self): with self.assertRaisesRegexp( RuntimeError, diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index caa9dd83233b6b850385335fde96431271d85c3a..dfae034afc9a115dcc97e401e8a6d9c66a9b46e9 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -195,7 +195,7 @@ def _replicate_model_fn_with_mode( if not devices: devices = _get_local_devices('GPU') or _get_local_devices('CPU') - is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper() consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0' ps_devices = [consolidation_device] @@ -457,6 +457,13 @@ def _get_local_devices(device_type): def _split_batch(features, labels, number_of_shards, device): """Split input features and labes into batches.""" + def ensure_divisible_by_shards(sequence): + batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0] + if batch_size % number_of_shards != 0: + raise ValueError( + 'Batch size {} needs to be divisible by the number of GPUs, which ' + 'is {}.'.format(batch_size, number_of_shards)) + def split_dictionary(dictionary): """Split a dictionary into shards.""" shards = [{} for _ in range(number_of_shards)] @@ -467,6 +474,7 @@ def _split_batch(features, labels, number_of_shards, device): sp_input=tensor, num_split=number_of_shards, axis=0)): shards[i][name] = shard else: + ensure_divisible_by_shards(tensor) for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): shards[i][name] = shard return shards @@ -476,6 +484,7 @@ def _split_batch(features, labels, number_of_shards, device): if isinstance(features, dict): feature_shards = split_dictionary(features) else: + ensure_divisible_by_shards(features) feature_shards = array_ops.split(features, number_of_shards) if labels is None: @@ -483,6 +492,7 @@ def _split_batch(features, labels, number_of_shards, device): elif isinstance(labels, dict): label_shards = split_dictionary(labels) else: + ensure_divisible_by_shards(labels) label_shards = array_ops.split(labels, number_of_shards) return feature_shards, label_shards diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 03d31226af613960a19ce116b19b30153b1fdcee..ab117e61a7059a224ebf6ff0355ae10363b758f5 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -37,6 +37,7 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -433,12 +434,51 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): 'probabilities': np.array([[0.1], [0.02]]) }, session.run(estimator_spec.predictions)) + def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self): + features = np.array([[1.0], [2.0], [3.0]]) + labels = np.array([[1.0], [2.0], [3.0]]) + + with self.assertRaisesRegexp( + ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + def test_unsupported_loss_reduction(self): with self.assertRaisesRegexp(ValueError, '.+none.+reduction.+is.+specified.+'): _ = replicate_model_fn.replicate_model_fn(self.model_fn, losses.Reduction.NONE) + def test_places_on_gpu_with_upper_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session(): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/GPU:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + + def test_places_on_gpu_with_lower_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session(): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + class ReplicateAcrossASingleDeviceWithoutTowerOptimizer( test_util.TensorFlowTestCase): @@ -981,8 +1021,13 @@ class SplitBatchTest(test_util.TensorFlowTestCase): return list(map(evaluate_items, first_list)), list( map(evaluate_items, second_list)) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + def test_simple_half_split(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -995,7 +1040,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards) def test_to_each_their_own(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -1008,7 +1053,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards) def test_one_batch(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -1021,7 +1066,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) def test_half_split_in_dictionary(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} labels = [10.0, 11.0, 12.0, 13.0] @@ -1035,6 +1080,60 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + def test_sparse_tensor_can_be_split_unevenly(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2], [2, 2]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]), + feature_shards[0]['x'].eval()) + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 2]], values=[3.], dense_shape=[1, 4]), + feature_shards[1]['x'].eval()) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + + def test_sparse_tensor_can_be_split_unevenly_repeated_row(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + print(feature_shards[0]['x'].eval()) + print(feature_shards[1]['x'].eval()) + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1., 2., 3.], + dense_shape=[2, 4]), feature_shards[0]['x'].eval()) + + second_batch = feature_shards[1]['x'].eval() + self.assertFalse(len(second_batch.indices)) + self.assertFalse(len(second_batch.values)) + self.assertAllEqual([1, 4], second_batch.dense_shape) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + def test_one_batch_in_dictionary(self): with self.test_session() as session: # pylint: disable=unused-variable features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} diff --git a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc index 31d08bfb65ea49e1378ffba480771d38ce16abec..a8c5d0763c28ba2b54f217405f0da65533f26b91 100644 --- a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc +++ b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc @@ -57,11 +57,11 @@ typedef Eigen::Map< class MaskedMatmulOp : public OpKernel { public: - explicit MaskedMatmulOp(OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->MatchSignature( - {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, - {DT_FLOAT})); + explicit MaskedMatmulOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK( + context, + context->MatchSignature( + {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, {DT_FLOAT})); } void Compute(OpKernelContext* context) override { @@ -110,12 +110,11 @@ class MaskedMatmulOp : public OpKernel { num_nonzero_elements, 2); Tensor* prod_values_tensor; - OP_REQUIRES_OK(context, - context->allocate_output( - 0, TensorShape({num_nonzero_elements}), - &prod_values_tensor)); - EigenMatFloatMap prod_values(prod_values_tensor->vec().data(), - 1, num_nonzero_elements); + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({num_nonzero_elements}), + &prod_values_tensor)); + EigenMatFloatMap prod_values(prod_values_tensor->vec().data(), 1, + num_nonzero_elements); auto get_a_index = [&indices_mat, &a_dim_0](int64 i) { int64 a_index = internal::SubtleMustCopy(indices_mat(i, 0)); @@ -182,8 +181,8 @@ class MaskedMatmulOp : public OpKernel { } }; // Shard the work. - worker_threads.workers->ParallelFor( - num_nonzero_elements, cost_per_unit, work); + worker_threads.workers->ParallelFor(num_nonzero_elements, cost_per_unit, + work); } }; REGISTER_KERNEL_BUILDER(Name("MaskedMatmul").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 4d0f9b24240ccbafe89ef912b4d3252cefb1f7f2..c861cfff544a78617aa1ace730b50c094cf16330 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -143,7 +143,7 @@ class _ModelFn(object): def model_fn(self, features, mode, config): """Model function for the estimator. - Note that this does not take a `1abels` arg. This works, but `input_fn` must + Note that this does not take a `labels` arg. This works, but `input_fn` must return either `features` or, equivalently, `(features, None)`. Args: diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index c85b1837ab5b0c1a3cea0525918f7717228d2fab..e61221a6b0d34373279a379f356c99c379488182 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -47,20 +47,19 @@ std::vector FfmpegAudioCommandLine(const string& input_filename, int32 channel_count, const string& stream) { std::vector command({ - "-nostats", // No additional progress display. - "-nostdin", // No interactive commands accepted. - "-f", input_format_id, // eg: "mp3" - "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, - "-loglevel", "error", // Print errors only. - "-hide_banner", // Skip printing build options, version, etc. - "-map_metadata", "-1", // Copy global metadata from input to output. - "-vn", // No video recording. - "-ac:a:0", StrCat(channel_count), "-ar:a:0", - StrCat(samples_per_second), - // Output set (in several ways) to signed 16-bit little-endian ints. - "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", - "-sn", // No subtitle recording. - "-y" // Overwrite output file. + "-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-f", input_format_id, // eg: "mp3" + "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, + "-loglevel", "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. + "-map_metadata", "-1", // Copy global metadata from input to output. + "-vn", // No video recording. + "-ac:a:0", StrCat(channel_count), "-ar:a:0", StrCat(samples_per_second), + // Output set (in several ways) to signed 16-bit little-endian ints. + "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", + "-sn", // No subtitle recording. + "-y" // Overwrite output file. }); if (!stream.empty()) { command.emplace_back("-map"); @@ -75,21 +74,13 @@ std::vector FfmpegVideoCommandLine(const string& input_filename, const string& output_filename) { return {"-nostats", // No additional progress display. "-nostdin", // No interactive commands accepted. - "-i", - input_filename, - "-f", - "image2pipe", - "-probesize", - StrCat(kDefaultProbeSize), - "-loglevel", + "-i", input_filename, "-f", "image2pipe", "-probesize", + StrCat(kDefaultProbeSize), "-loglevel", // Info is needed to get the information about stream, etc. // It is generated to a separate file, not stdout/stderr. "info", "-hide_banner", // Skip printing build options, version, etc. - "-vcodec", - "rawvideo", - "-pix_fmt", - "rgb24", + "-vcodec", "rawvideo", "-pix_fmt", "rgb24", "-y", // Overwrite output file. StrCat(output_filename)}; } diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 85b61b26163d87a10d4e316720b4f633e038bbec..05728b3d37570d06f2f8af67e3b0612d21d07601 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -32,10 +32,8 @@ namespace tensorflow { namespace ffmpeg { namespace { -const char kTestWavFilename[] = - "contrib/ffmpeg/testdata/mono_10khz.wav"; -const char kTestMp3Filename[] = - "contrib/ffmpeg/testdata/test_sound1.mp3"; +const char kTestWavFilename[] = "contrib/ffmpeg/testdata/mono_10khz.wav"; +const char kTestMp3Filename[] = "contrib/ffmpeg/testdata/test_sound1.mp3"; // Set to true via a command line flag iff the test is expected to have FFmpeg // installed. @@ -139,7 +137,7 @@ TEST(FfmpegLibTest, TestRoundTripWav) { } // namespace ffmpeg } // namespace tensorflow -int main(int argc, char **argv) { +int main(int argc, char** argv) { tensorflow::string usage = tensorflow::ffmpeg::ParseTestFlags(&argc, argv); testing::InitGoogleTest(&argc, argv); if (argc != 1) { diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 36fc71794b06e0f3cb86c40b325ce50e8999c667..d6c885a32424334bfc28c830e3701f219aa244ee 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -20,8 +20,6 @@ #include #include - -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 673c51784229bd88011f8b33fb851a2885566220..fb101c36538f72d0665c41a625824eb0d66f48ce 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -53,6 +53,7 @@ See the @{$python/contrib.framework} guide. @@assign_from_values_fn @@create_global_step @@filter_variables +@@fuse_op @@get_global_step @@get_or_create_global_step @@get_local_variables @@ -85,6 +86,9 @@ See the @{$python/contrib.framework} guide. @@sort @@CriticalSection + +@@BoundedTensorSpec +@@TensorSpec """ from __future__ import absolute_import @@ -99,6 +103,9 @@ from tensorflow.contrib.framework.python.ops import * from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope +from tensorflow.python.framework.tensor_spec import BoundedTensorSpec +from tensorflow.python.framework.tensor_spec import TensorSpec + from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc index 6677dca752f84fc1ba7548b7739df04b7aaf14f7..5bf6b67529579e71a615c27e035111a58d5c02e0 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/contrib/framework/kernels/zero_initializer_op.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -81,8 +81,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); #define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #undef REGISTER_KERNELS -} // namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.h b/tensorflow/contrib/framework/kernels/zero_initializer_op.h index 14c9268efa869ffd48b01dd2add44990ef7a43f8..99389a5ab6aa73c2ab0e522dd0f9fbc7093c8f4a 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.h +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.h @@ -29,5 +29,5 @@ struct TensorSetZero { }; } // namespace functor -} // end namespace tensorflow -#endif // TENSORFLOW_CONTRIB_FRAMEWORK_KERNELS_ZERO_INITIALIZER_OP_H_ +} // end namespace tensorflow +#endif // TENSORFLOW_CONTRIB_FRAMEWORK_KERNELS_ZERO_INITIALIZER_OP_H_ diff --git a/tensorflow/contrib/framework/ops/variable_ops.cc b/tensorflow/contrib/framework/ops/variable_ops.cc index 1ee8e1498cf07559fe3db78ef832e2cdf26bea1c..706134ba9a51de6253ba7463b17ff662ea740ed0 100644 --- a/tensorflow/contrib/framework/ops/variable_ops.cc +++ b/tensorflow/contrib/framework/ops/variable_ops.cc @@ -26,8 +26,8 @@ REGISTER_OP("ZeroInitializer") .Attr("T: realnumbertype") .SetAllowsUninitializedInput() .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); + c->set_output(0, c->input(0)); + return Status::OK(); }) .Doc(R"doc( Initialize 'ref' with all zeros. This op requires that the tensor is not diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py index b5e9f8df79262635bf579a6bf2260bc40c140c6f..6f65fe771eb77c10d0914faa90886b587adae68c 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import googletest - class AccumulateNV2Test(test_util.TensorFlowTestCase): """Tests of the new, differentiable version of accumulate_n""" @@ -62,8 +61,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): accum_n = av2.accumulate_n_v2(input_vars) sess.run(variables.global_variables_initializer()) accum_n_grad = gradients.gradients(accum_n, input_vars) - self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 - [g.eval() for g in accum_n_grad]) + self.assertAllEqual( + np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [g.eval() for g in accum_n_grad]) # The tests below used to be in a separate class under cwise_ops_test.py, # which did not run in the default test target. @@ -75,8 +75,8 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20) ] random_tensors = [ - ops.convert_to_tensor( - x, dtype=dtypes_lib.float32) for x in random_arrays + ops.convert_to_tensor(x, dtype=dtypes_lib.float32) + for x in random_arrays ] tf_val = av2.accumulate_n_v2(random_tensors) np_val = random_arrays[0] @@ -95,21 +95,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(0.2) b = variables.Variable(0.1) - tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[] + tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[] def testIncompatibleShapes(self): with self.test_session(): with self.assertRaises(ValueError): - a = variables.Variable(np.array([0.1,0.2])) - b = variables.Variable(np.array([[0.3],[0.4]])) - tf_val = av2.accumulate_n_v2([a,b]) + a = variables.Variable(np.array([0.1, 0.2])) + b = variables.Variable(np.array([[0.3], [0.4]])) + tf_val = av2.accumulate_n_v2([a, b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 0d51c282a8977871185fb4200082feb7868cdbae..082c42eba180917e732bb7890129dfa94bf00fec 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -59,7 +59,11 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. + This Estimator is backed by TFGAN. The network functions follow the TFGAN API + except for one exception: if either `generator_fn` or `discriminator_fn` have + an argument called `mode`, then the tf.Estimator mode is passed in for that + argument. This helps with operations like batch normalization, which have + different train and evaluation behavior. Example: @@ -233,9 +237,11 @@ def _gan_model_fn( def _make_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries, mode): """Make a `GANModel`, and optionally pass in `mode`.""" - # If `generator_fn` has an argument `mode`, pass mode to it. + # If network functions have an argument `mode`, pass mode to it. if 'mode' in inspect.getargspec(generator_fn).args: generator_fn = functools.partial(generator_fn, mode=mode) + if 'mode' in inspect.getargspec(discriminator_fn).args: + discriminator_fn = functools.partial(discriminator_fn, mode=mode) gan_model = tfgan_train.gan_model( generator_fn, discriminator_fn, diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index e752f0bcccda418b79d4fdabb27807394cbbb425..387a62bd741bd42c03dc1bf70592060c29ccd7a8 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -54,7 +54,8 @@ def generator_fn(noise_dict, mode): return layers.fully_connected(noise, noise.shape[1].value) -def discriminator_fn(data, _): +def discriminator_fn(data, unused_conditioning, mode): + del unused_conditioning, mode return layers.fully_connected(data, 1) @@ -99,7 +100,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data, else: testcase.assertEqual(discriminator_scope_name, gan_model.discriminator_scope.name) - testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn) with ops.control_dependencies(assertions): if mode == model_fn_lib.ModeKeys.TRAIN: diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 986a5ff6dcbeb2ff996f49137adc6d34e14c979f..d9b07e62f89d61c72a34dfa844f11ad1238fb006 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -28,6 +28,7 @@ from __future__ import division from __future__ import print_function import functools +import os import sys import tarfile @@ -189,20 +190,31 @@ def get_graph_def_from_resource(filename): return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename)) -def get_graph_def_from_url_tarball(url, filename): - """Get a GraphDef proto from a tarball on the web.""" - def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % ( - url, float(count * block_size) / float(total_size) * 100.0)) - sys.stdout.flush() - tar_filename, _ = urllib.request.urlretrieve(url, reporthook=_progress) +def get_graph_def_from_url_tarball(url, filename, tar_filename=None): + """Get a GraphDef proto from a tarball on the web. + + Args: + url: Web address of tarball + filename: Filename of graph definition within tarball + tar_filename: Temporary download filename (None = always download) + + Returns: + A GraphDef loaded from a file in the downloaded tarball. + """ + if not (tar_filename and os.path.exists(tar_filename)): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % ( + url, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress) with tarfile.open(tar_filename, 'r:gz') as tar: proto_str = tar.extractfile(filename).read() return graph_pb2.GraphDef.FromString(proto_str) def _default_graph_def_fn(): - return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH) + return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH, + os.path.basename(INCEPTION_URL)) def run_inception(images, diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py index b960af28eaa969079b72c7aabcde2ad6cd1f5c68..871f1ad54e2559f5df28efa78f99997a866f7087 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py @@ -84,11 +84,11 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose( np.array([0.014, 0.014], 'f'), np.array([x[0] for x in wscores], 'f'), - rtol=0.1) + rtol=0.15) self.assertAllClose( np.array([0.014, 0.020], 'f'), np.array([x[1] for x in wscores], 'f'), - rtol=0.1) + rtol=0.15) def test_sliced_wasserstein_distance_svd(self): """Test the distance.""" diff --git a/tensorflow/contrib/gdr/README.md b/tensorflow/contrib/gdr/README.md index 34ce60b360822888aa6223c89362ae1b0d9d991f..8242d93f129904828a11b61d48f2df8fb0f88bc3 100644 --- a/tensorflow/contrib/gdr/README.md +++ b/tensorflow/contrib/gdr/README.md @@ -119,4 +119,4 @@ In the original design (as in the reference), tensor buffers are only registered Reference === -Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos'17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3123907 +Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos'17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3131975 diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 5c7ac744289ab7729b4cc43ab9bedc9342284e65..81e70ae30a4c72dbcedd1aabfe758ecca4c8b366 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -86,8 +86,9 @@ int TryToReadNumaNode(ibv_device* device) { if (strings::safe_strto32(content, &value)) { if (value < 0) { LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; + << value + << "), but there must be at least one NUMA node" + ", so returning NUMA node zero"; return 0; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -290,8 +291,8 @@ Status GdrMemoryManager::Init() { // Host memory allocators for (Allocator* allocator : allocators) { auto* visitable_allocator = dynamic_cast(allocator); - CHECK(visitable_allocator) << "is not visitable for instrumentation" - << allocator->Name(); + CHECK(visitable_allocator) + << "is not visitable for instrumentation" << allocator->Name(); // Make sure we don't instrument the same allocator twice if (instrumented_.find(allocator) == std::end(instrumented_)) { visitable_allocator->AddAllocVisitor(alloc_visitor); @@ -635,8 +636,8 @@ void GdrMemoryManager::TensorFromTransportOptions( } else { checksum = GPUUtil::Checksum(*tensor); } - CHECK(checksum == remote_mr.checksum()) << "Checksum mismatch: " << checksum - << "!=" << remote_mr.checksum(); + CHECK(checksum == remote_mr.checksum()) + << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); #endif } done(Status::OK()); diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 6adf837ca0ab506bd18f5e2e1fc1847e31d782bf..c2e32da133b32c8fe169302668031af8bace2c22 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -43,9 +43,9 @@ template struct FillProjectiveTransform; typedef Eigen::ThreadPoolDevice CPUDevice; using functor::FillProjectiveTransform; +using generator::Interpolation; using generator::INTERPOLATION_BILINEAR; using generator::INTERPOLATION_NEAREST; -using generator::Interpolation; using generator::ProjectiveGenerator; template @@ -72,11 +72,12 @@ class ImageProjectiveTransform : public OpKernel { const Tensor& transform_t = ctx->input(1); OP_REQUIRES(ctx, images_t.shape().dims() == 4, errors::InvalidArgument("Input images must have rank 4")); - OP_REQUIRES(ctx, (TensorShapeUtils::IsMatrix(transform_t.shape()) && - (transform_t.dim_size(0) == images_t.dim_size(0) || - transform_t.dim_size(0) == 1) && - transform_t.dim_size(1) == - ProjectiveGenerator::kNumParameters), + OP_REQUIRES(ctx, + (TensorShapeUtils::IsMatrix(transform_t.shape()) && + (transform_t.dim_size(0) == images_t.dim_size(0) || + transform_t.dim_size(0) == 1) && + transform_t.dim_size(1) == + ProjectiveGenerator::kNumParameters), errors::InvalidArgument( "Input transform should be num_images x 8 or 1 x 8")); auto images = images_t.tensor(); diff --git a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc index 9f0bf37aed3fc9aeefb7602ef3fda4cfd76f1917..8f9a5c28039b74a874028826ca8a6d5a36ab7cf4 100755 --- a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc @@ -143,8 +143,8 @@ class SingleImageRandomDotStereogramsOp : public OpKernel { } data_box_left = deltaX_border_image / 2; // Center DATA in X dimension - data_box_width = data_Xwindow; // width of scan line - data_box_height = data_Ywindow; // hight of image + data_box_width = data_Xwindow; // width of scan line + data_box_height = data_Ywindow; // hight of image const T* inputZ = input_tensor.flat().data(); // Flatten input Z buffer diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index 1f41f243f2ebc0d1e884728defa160bf6d6c34ce..8139d4272d6950815bd39a64e86e0f7422e6f799 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -58,7 +58,9 @@ REGISTER_OP("SingleImageRandomDotStereograms") int colors; TF_RETURN_IF_ERROR(c->GetAttr("number_colors", &colors)); - c->set_output(0, c->MakeShape({y_dim, x_dim, colors > 256? c->MakeDim(3) : c->MakeDim(1)})); + c->set_output( + 0, c->MakeShape( + {y_dim, x_dim, colors > 256 ? c->MakeDim(3) : c->MakeDim(1)})); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 63377ae50310db51a3111c5a6e00df7d75dccc0b..c139ae89d8d682d6b87813c3a21703ffa762f28e 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -40,7 +40,7 @@ ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) def rotate(images, angles, interpolation="NEAREST", name=None): - """Rotate image(s) by the passed angle(s) in radians. + """Rotate image(s) counterclockwise by the passed angle(s) in radians. Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) @@ -290,31 +290,76 @@ def compose_transforms(*transforms): """ assert transforms, "transforms cannot be empty" with ops.name_scope("compose_transforms"): - composed = _flat_transforms_to_matrices(transforms[0]) + composed = flat_transforms_to_matrices(transforms[0]) for tr in transforms[1:]: # Multiply batches of matrices. - composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) - return _transform_matrices_to_flat(composed) + composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr)) + return matrices_to_flat_transforms(composed) -def _flat_transforms_to_matrices(transforms): - # Make the transform(s) 2D in case the input is a single transform. - transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) - num_transforms = array_ops.shape(transforms)[0] - # Add a column of ones for the implicit last entry in the matrix. - return array_ops.reshape( - array_ops.concat( - [transforms, array_ops.ones([num_transforms, 1])], axis=1), - constant_op.constant([-1, 3, 3])) +def flat_transforms_to_matrices(transforms): + """Converts `tf.contrib.image` projective transforms to affine matrices. + Note that the output matrices map output coordinates to input coordinates. For + the forward transformation matrix, call `tf.linalg.inv` on the result. -def _transform_matrices_to_flat(transform_matrices): - # Flatten each matrix. - transforms = array_ops.reshape(transform_matrices, - constant_op.constant([-1, 9])) - # Divide each matrix by the last entry (normally 1). - transforms /= transforms[:, 8:9] - return transforms[:, :8] + Args: + transforms: Vector of length 8, or batches of transforms with shape + `(N, 8)`. + + Returns: + 3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the + *output coordinates* (in homogeneous coordinates) of each transform to the + corresponding *input coordinates*. + + Raises: + ValueError: If `transforms` have an invalid shape. + """ + with ops.name_scope("flat_transforms_to_matrices"): + transforms = ops.convert_to_tensor(transforms, name="transforms") + if transforms.shape.ndims not in (1, 2): + raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms) + # Make the transform(s) 2D in case the input is a single transform. + transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) + num_transforms = array_ops.shape(transforms)[0] + # Add a column of ones for the implicit last entry in the matrix. + return array_ops.reshape( + array_ops.concat( + [transforms, array_ops.ones([num_transforms, 1])], axis=1), + constant_op.constant([-1, 3, 3])) + + +def matrices_to_flat_transforms(transform_matrices): + """Converts affine matrices to `tf.contrib.image` projective transforms. + + Note that we expect matrices that map output coordinates to input coordinates. + To convert forward transformation matrices, call `tf.linalg.inv` on the + matrices and use the result here. + + Args: + transform_matrices: One or more affine transformation matrices, for the + reverse transformation in homogeneous coordinates. Shape `(3, 3)` or + `(N, 3, 3)`. + + Returns: + 2D tensor of flat transforms with shape `(N, 8)`, which may be passed into + `tf.contrib.image.transform`. + + Raises: + ValueError: If `transform_matrices` have an invalid shape. + """ + with ops.name_scope("matrices_to_flat_transforms"): + transform_matrices = ops.convert_to_tensor( + transform_matrices, name="transform_matrices") + if transform_matrices.shape.ndims not in (2, 3): + raise ValueError( + "Matrices should be 2D or 3D, got: %s" % transform_matrices) + # Flatten each matrix. + transforms = array_ops.reshape(transform_matrices, + constant_op.constant([-1, 9])) + # Divide each matrix by the last entry (normally 1). + transforms /= transforms[:, 8:9] + return transforms[:, :8] @ops.RegisterGradient("ImageProjectiveTransform") @@ -346,9 +391,9 @@ def _image_projective_transform_grad(op, grad): raise TypeError("Transforms should have rank 1 or 2.") # Invert transformations - transforms = _flat_transforms_to_matrices(transforms=transforms) + transforms = flat_transforms_to_matrices(transforms=transforms) inverse = linalg_ops.matrix_inverse(transforms) - transforms = _transform_matrices_to_flat(inverse) + transforms = matrices_to_flat_transforms(inverse) output = gen_image_ops.image_projective_transform( grad, transforms, interpolation=interpolation) if len(image_or_images.get_shape()) == 2: diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc index ca288c1f737d25faac678f5c199d5c1e49f721cb..886f6798150c57d8066546b0919481d3878882fc 100644 --- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc +++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc @@ -34,9 +34,8 @@ class ObtainNextOp : public OpKernel { // Allocate output. Tensor* output_tensor = nullptr; - OP_REQUIRES_OK( - ctx, - ctx->allocate_output("out_element", TensorShape({}), &output_tensor)); + OP_REQUIRES_OK(ctx, ctx->allocate_output("out_element", TensorShape({}), + &output_tensor)); // Obtain mutex for the "counter" tensor. mutex* mu; diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 6590d86ebb7d9da836e5777af7d517919f4e2eff..e561f595a405280010a54d761bdb378ec0162ac0 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -30,7 +30,8 @@ class KafkaDataset(Dataset): """A Kafka Dataset that consumes the message. """ - def __init__(self, topics, servers="localhost", group="", eof=False, timeout=1000): + def __init__( + self, topics, servers="localhost", group="", eof=False, timeout=1000): """Create a KafkaReader. Args: diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 0f0dbb53f45dfefe69aaa9e25caf6ba0a3cf449e..87eed03888c894a04c0521d1ce5ee8975b60776b 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -317,7 +317,10 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): return tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) + run_config = tf.estimator.RunConfig( + model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100) + # Train until input_fn() is empty with Estimator. This is a prerequisite for # TPU compatibility. - estimator = tf.estimator.Estimator(model_fn=model_fn) + estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index f59168cbc05fffd104ff5a44308eefd206beb9db..bcba18ae147c6ceca50bc9a2a17e01fc201d88c1 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -111,6 +112,54 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) +def extract_image_patches(image, ksizes, strides, padding, name=None): + """Extracts image patches for an N-dimensional convolution. + + This function is a compatibility wrapper over tf.extract_image_patches(), as + ExtractImagePatches isn't yet implemented in XLA. + + Args: + image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images. + All dimensions except 'batch' must be defined. + ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each + dimension. + strides: [stride_x, stride_y, ...]. Spatial stride for filter in each + dimension. + padding: str. "VALID" or "SAME". + name: str or None. name of Op. + + Returns: + result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels]. + Contains image patches to which conv kernel would be applied for each + output location. [out_x, out_y, ...] depends on padding. + """ + if not utils.on_tpu(): + return array_ops.extract_image_patches( + image, + ksizes=([1] + list(ksizes) + [1]), + strides=([1] + list(strides) + [1]), + rates=[1, 1, 1, 1], + padding=padding, + name=name) + + with tf_ops.name_scope(name, "extract_image_patches", + [image, ksizes, strides, padding]): + batch = image.shape.as_list()[0] + in_channels = image.shape.as_list()[-1] + + # Map each input feature to a location in the output. + out_channels = np.prod(ksizes) * in_channels + filters = linalg_ops.eye(out_channels), + filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels]) + + result = nn.convolution(image, filters, padding, strides=strides) + out_spatial = result.shape.as_list()[1:-1] + result = array_ops.reshape( + result, [batch or -1] + out_spatial + ksizes + [in_channels]) + + return result + + def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. @@ -668,11 +717,10 @@ class ConvDiagonalFactor(DiagonalFactor): # TODO(b/64144716): there is potential here for a big savings in terms # of memory use. - patches = array_ops.extract_image_patches( + patches = extract_image_patches( self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], + ksizes=[filter_height, filter_width], + strides=self._strides[1:-1], padding=self._padding) if self._has_bias: @@ -816,11 +864,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): # TODO(b/64144716): there is potential here for a big savings in terms of # memory use. - patches = array_ops.extract_image_patches( + patches = extract_image_patches( self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], + ksizes=[filter_height, filter_width], + strides=self._strides[1:-1], padding=self._padding) flatten_size = (filter_height * filter_width * in_channels) diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index cc48e3c69f24c2abd343e2e120d3589cd323fcdc..fe8e39c212c2c3381f9aa6fdb9fdf423ff958481 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -24,6 +24,7 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "set_global_constants", "SequenceDict", "tensors_to_column", "column_to_tensors", diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 932c5ab99249feda1e3a7f2d707ce4237fe7177f..01893d60615a9b4ded2afc88c6de0168d4be0921 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -423,8 +423,9 @@ class SparseFeatureCrossOp : public OpKernel { "Input values should be a std::vector but received shape ", values_list_in[i].shape().DebugString(), " at position ", i)); OP_REQUIRES( - context, indices_list_in[i].shape().dim_size(0) == - values_list_in[i].shape().dim_size(0), + context, + indices_list_in[i].shape().dim_size(0) == + values_list_in[i].shape().dim_size(0), errors::InvalidArgument( "Expected size of values to be ", indices_list_in[i].shape().dim_size(0), " got ", diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index c8e3307ee8b5ded30dc864c4e69452f58685b8f0..fb7b2e315efb773770eda8c07e52c4850e48e4da 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -60,12 +60,12 @@ __all__ = [ 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', - 'fully_connected', 'GDN', 'gdn', 'layer_norm', 'linear', 'pool', - 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', - 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'softmax', - 'spatial_softmax', 'stack', 'unit_norm', 'legacy_fully_connected', - 'legacy_linear', 'legacy_relu', 'maxout' + 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', 'layer_norm', + 'linear', 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', + 'relu6', 'repeat', 'scale_gradient', 'separable_conv2d', + 'separable_convolution2d', 'softmax', 'spatial_softmax', 'stack', + 'unit_norm', 'legacy_fully_connected', 'legacy_linear', 'legacy_relu', + 'maxout' ] DATA_FORMAT_NCHW = 'NCHW' @@ -1418,7 +1418,9 @@ def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): with variable_scope.variable_scope( scope, 'dense_to_sparse', [tensor]) as sc: tensor = ops.convert_to_tensor(tensor) - indices = array_ops.where(math_ops.not_equal(tensor, constant_op.constant(eos_token, tensor.dtype))) + indices = array_ops.where( + math_ops.not_equal( + tensor, constant_op.constant(eos_token, tensor.dtype))) values = array_ops.gather_nd(tensor, indices) shape = array_ops.shape(tensor, out_type=dtypes.int64) outputs = sparse_tensor.SparseTensor(indices, values, shape) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index f05f63224636967e99673c04fb9de60af34e4268..8945690db8ee233e61645c38e6e4d615c4f0da66 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1308,7 +1308,8 @@ class DenseToSparseTest(test.TestCase): expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2)) tensor = constant_op.constant(expected_constant) sparse = _layers.dense_to_sparse(tensor) - dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, sparse.values) + dense = sparse_ops.sparse_to_dense( + sparse.indices, sparse.dense_shape, sparse.values) with self.test_session() as sess: constant = sess.run(dense) self.assertAllEqual(expected_constant, constant) diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index 71978d439449e29c7cb907b18bab5d6659a972b6..18bf16e246bcb6c0a6a4ce75bc5c28d4e0d045e5 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Base utilities for loading datasets.""" from __future__ import absolute_import @@ -100,9 +99,7 @@ def load_iris(data_path=None): module_path = path.dirname(__file__) data_path = path.join(module_path, 'data', 'iris.csv') return load_csv_with_header( - data_path, - target_dtype=np.int, - features_dtype=np.float) + data_path, target_dtype=np.int, features_dtype=np.float) def load_boston(data_path=None): @@ -118,16 +115,10 @@ def load_boston(data_path=None): module_path = path.dirname(__file__) data_path = path.join(module_path, 'data', 'boston_house_prices.csv') return load_csv_with_header( - data_path, - target_dtype=np.float, - features_dtype=np.float) + data_path, target_dtype=np.float, features_dtype=np.float) -def retry(initial_delay, - max_delay, - factor=2.0, - jitter=0.25, - is_retriable=None): +def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): """Simple decorator for wrapping retriable functions. Args: @@ -152,7 +143,7 @@ def retry(initial_delay, def delays(): delay = initial_delay while delay <= max_delay: - yield delay * random.uniform(1 - jitter, 1 + jitter) + yield delay * random.uniform(1 - jitter, 1 + jitter) delay *= factor def wrap(fn): @@ -172,7 +163,9 @@ def retry(initial_delay, else: raise return fn(*args, **kwargs) + return wrapped_fn + return wrap diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py index 19791d7759e05f477f1b79220f29831fb2240e1b..5809995c8c7d8e72eb47ee88a72547bae7fd3594 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py @@ -136,7 +136,7 @@ class SyntheticTest(test.TestCase): self.assertRaises(AssertionError, np.testing.assert_array_equal, spir0.data, spir1.data) - def test_spirals(self): + def test_spirals_synthetic(self): synthetic.spirals(3) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 12f9bba531a296a00d17956b8ce32e5d7dead380..2bd57597c2e9444b51b1dacfbe4180b443c95a3d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -1224,7 +1224,7 @@ class DNNRegressorTest(test.TestCase): self, predictions, expected_shape): predictions_nparray = np.array(predictions) self.assertAllEqual(expected_shape, predictions_nparray.shape) - self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.float)) + self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.floating)) def testPredict_AsIterableFalse(self): """Tests predict method with as_iterable=False.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 8d59fe66d98b2ca7dc143cfdf05d29629e3bf616..63d0f1e1d454354948654e8ad4208a8852d356ca 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -600,7 +600,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, input_fn=None, batch_size=None, outputs=None, - as_iterable=True): + as_iterable=True, + iterate_batches=False): """Returns predictions for given features. Args: @@ -616,6 +617,9 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, for each example until inputs are exhausted. Note: The inputs must terminate if you want the iterable to terminate (e.g. be sure to pass num_epochs=1 if you are using something like read_batch_features). + iterate_batches: If True, yield the whole batch at once instead of + decomposing the batch into individual samples. Only relevant when + as_iterable is True. Returns: A numpy array of predicted classes or regression values if the @@ -635,7 +639,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, input_fn=input_fn, feed_fn=feed_fn, outputs=outputs, - as_iterable=as_iterable) + as_iterable=as_iterable, + iterate_batches=iterate_batches) def get_variable_value(self, name): """Returns value of the variable given by name. diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 0948dee7e2fa1b1b3617abd08d2d43ebc5340f63..51381a7427c919592b8e818c4b46dba974992610 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -879,7 +879,7 @@ class GraphDump(BaseMonitor): this_output = self.data[step] if step in self.data else {} other_output = other_dump.data[step] if step in other_dump.data else {} for key in this_output: - if not isinstance(key, str) and not isinstance(key, unicode): + if not isinstance(key, six.string_types): continue if key not in other_output: raise ValueError("%s missing at step %s.", (key, step)) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 13350c5a438b75fe14e8753e5bb1bb77ec8f655b..cc0e20f75ee74b3df2f46f8df179aa0c33741c5f 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -53,6 +53,8 @@ cc_test( srcs = ["arena_planner_test.cc"], deps = [ ":arena_planner", + "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:lib", "@com_google_googletest//:gtest", ], ) @@ -167,6 +169,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index bf1bcdd1a7a7d3395c45ae95abd5980e9ffc0fc6..87b17c338e7afc33d32dd9688cc0825ac319dd19 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -185,8 +185,12 @@ TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) { TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) { TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_STATUS( - arena_.ResolveAlloc(context_, allocs_[tensor_index], &tensor.data.raw)); + // Skip resolution if the size of the tensor is zero, leaving it as a + // nullptr. + if (allocs_[tensor_index].size != 0) { + TF_LITE_ENSURE_STATUS(arena_.ResolveAlloc(context_, allocs_[tensor_index], + &tensor.data.raw)); + } } if (tensor.allocation_type == kTfLiteArenaRwPersistent) { TF_LITE_ENSURE_STATUS(persistent_arena_.ResolveAlloc( diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index c27c327abc63d7bd1e3912d368a1dacb62c50ca8..a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/core/platform/logging.h" namespace tflite { namespace { @@ -191,8 +193,8 @@ TEST_F(ArenaPlannerTest, GraphWithNoOps) { EXPECT_EQ(GetOffset(10), GetOffsetAfter(0)); // The outputs are never allocated because they are not connected to any // inputs. - EXPECT_EQ(GetOffset(5), 0); - EXPECT_EQ(GetOffset(11), 0); + EXPECT_TRUE((*graph.tensors())[5].data.raw == nullptr); + EXPECT_TRUE((*graph.tensors())[11].data.raw == nullptr); } TEST_F(ArenaPlannerTest, GraphWithOneOp) { @@ -371,11 +373,7 @@ TEST_F(ArenaPlannerTest, LargerGraphAndStepwiseAllocation) { SetGraph(&graph); auto is_unallocated = [&](int tensor_index) { - // TODO(ahentz): We'd to use nullptr to represent unallocated tensors, but - // the current code still points them all to the beginning fo the alloc - // (that is, zero offset). - // return (*graph.tensors())[tensor_index].data.raw == nullptr; - return GetOffset(tensor_index) == 0; + return (*graph.tensors())[tensor_index].data.raw == nullptr; }; // The allocation plan is made at the beginning and is independent of @@ -464,9 +462,7 @@ TEST_F(ArenaPlannerTest, LargerGraphAndStepwiseAllocation) { } // namespace tflite int main(int argc, char** argv) { - // ::tflite::LogToStderr(); - FLAGS_logtostderr = true; - + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 8966e5c2b602b22df5fc9ecded5d79e7bb80e7e2..5dbeadd16582ec586adab100b8a46e10182bd5ee 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -116,25 +116,9 @@ typedef struct { } TfLiteAddParams; typedef struct { - // Number of spatial dimensions. - // For now only NHWC is supported, and the value should always be 2. - int num_spatial_dimensions; - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int block_shape[2]; - int before_paddings[2]; - int after_paddings[2]; } TfLiteSpaceToBatchNDParams; typedef struct { - // Number of spatial dimensions. - // For now only NHWC is supported, and the value should always be 2. - int num_spatial_dimensions; - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int block_shape[2]; - int before_crops[2]; - int after_crops[2]; } TfLiteBatchToSpaceNDParams; typedef struct { @@ -167,6 +151,7 @@ typedef struct { } TfLiteLSTMParams; typedef struct { + bool align_corners; } TfLiteResizeBilinearParams; typedef struct { @@ -207,10 +192,6 @@ typedef struct { } TfLiteTransposeParams; typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int axis[8]; - int num_axis_dimensions; bool keep_dims; } TfLiteMeanParams; diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index 10f31bb6f17242c9f7f70f0648ec643f99c5ac86..d74e275f0439b1ce56b29e0eadff5f211f6a4faa 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -225,14 +225,8 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const assert(pixelBuffer != NULL); OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); - int doReverseChannels; - if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { - doReverseChannels = 1; - } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { - doReverseChannels = 0; - } else { - assert(false); // Unknown source format - } + assert(sourcePixelFormat == kCVPixelFormatType_32ARGB || + sourcePixelFormat == kCVPixelFormatType_32BGRA); const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index 476d85c0314e331d6d3bad382c331a8458fd01a1..d216cdf69ba8fd3de5a665e5d8b29aa5e01bff13 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -42,7 +42,10 @@ cc_library( "bitmap_helpers_impl.h", "label_image.h", ], - deps = ["//tensorflow/contrib/lite:string"], + deps = [ + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], ) # TODO(ahentz): Test disabled as it has a memory leek from read_bmp diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 860e27e5ba9cc9fe23d2a7f9f65dd53bbf76f7a3..471fda2ba465aa5ccad2985a063a6855b7488a05 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -26,15 +26,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, int* channels, Settings* s); template -void downsize(T* out, uint8_t* in, int image_height, int image_width, - int image_channels, int wanted_height, int wanted_width, - int wanted_channels, Settings* s); +void resize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s); // explicit instantiation -template void downsize(uint8_t*, unsigned char*, int, int, int, int, - int, int, Settings*); -template void downsize(float*, unsigned char*, int, int, int, int, int, - int, Settings*); +template void resize(uint8_t*, unsigned char*, int, int, int, int, + int, int, Settings*); +template void resize(float*, unsigned char*, int, int, int, int, int, + int, Settings*); } // namespace label_image } // namespace tflite diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 64a931082b0cbb4632ec3a814ce654d4f9106bc1..33ea695dda8a27ab2f0dd1c75538833debb26b95 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -16,30 +16,76 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H #define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/version.h" + #include "tensorflow/contrib/lite/examples/label_image/label_image.h" namespace tflite { namespace label_image { template -void downsize(T* out, uint8_t* in, int image_height, int image_width, - int image_channels, int wanted_height, int wanted_width, - int wanted_channels, Settings* s) { - for (int y = 0; y < wanted_height; ++y) { - const int in_y = (y * image_height) / wanted_height; - uint8_t* in_row = in + (in_y * image_width * image_channels); - T* out_row = out + (y * wanted_width * wanted_channels); - for (int x = 0; x < wanted_width; ++x) { - const int in_x = (x * image_width) / wanted_width; - uint8_t* in_pixel = in_row + (in_x * image_channels); - T* out_pixel = out_row + (x * wanted_channels); - for (int c = 0; c < wanted_channels; ++c) { - if (s->input_floating) - out_pixel[c] = (in_pixel[c] - s->input_mean) / s->input_std; - else - out_pixel[c] = in_pixel[c]; - } - } +void resize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s) { + + int number_of_pixels = image_height * image_width * image_channels; + std::unique_ptr interpreter(new Interpreter); + + int base_index = 0; + + // two inputs: input and new_sizes + interpreter->AddTensors(2, &base_index); + // one output + interpreter->AddTensors(1, &base_index); + // set input and output tensors + interpreter->SetInputs({0, 1}); + interpreter->SetOutputs({2}); + + // set paramters of tensors + TfLiteQuantizationParams quant; + interpreter->SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "input", + {1, image_height, image_width, image_channels}, quant); + interpreter->SetTensorParametersReadWrite(1, kTfLiteInt32, "new_size", {2}, + quant); + interpreter->SetTensorParametersReadWrite( + 2, kTfLiteFloat32, "output", + {1, wanted_height, wanted_width, wanted_channels}, quant); + + ops::builtin::BuiltinOpResolver resolver; + TfLiteRegistration* resize_op = + resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); + interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, + resize_op, nullptr); + + interpreter->AllocateTensors(); + + // fill input image + // in[] are integers, cannot do memcpy() directly + auto input = interpreter->typed_tensor(0); + for (int i = 0; i < number_of_pixels; i++) { + input[i] = in[i]; + } + + // fill new_sizes + interpreter->typed_tensor(1)[0] = wanted_height; + interpreter->typed_tensor(1)[1] = wanted_width; + + interpreter->Invoke(); + + auto output = interpreter->typed_tensor(2); + auto output_number_of_pixels = + wanted_height * wanted_height * wanted_channels; + + for (int i = 0; i < output_number_of_pixels; i++) { + if (s->input_floating) + out[i] = (output[i] - s->input_mean) / s->input_std; + else + out[i] = (uint8_t)output[i]; } } diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index d7f49ad8757e8899fe9c23b985edff6ba7f68750..a78900122efa540322a2f80fa3a98e6a8985ddd5 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -151,14 +151,14 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - downsize(interpreter->typed_tensor(input), in, - image_height, image_width, image_channels, - wanted_height, wanted_width, wanted_channels, s); + resize(interpreter->typed_tensor(input), in, + image_height, image_width, image_channels, + wanted_height, wanted_width, wanted_channels, s); break; case kTfLiteUInt8: - downsize(interpreter->typed_tensor(input), in, - image_height, image_width, image_channels, - wanted_height, wanted_width, wanted_channels, s); + resize(interpreter->typed_tensor(input), in, + image_height, image_width, image_channels, + wanted_height, wanted_width, wanted_channels, s); break; default: LOG(FATAL) << "cannot handle input type " diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 69a597dc5a219b55eced6ec8da5b388caf372b8e..9dd60abc8639a0e051f717d5e41921271a78bf8a 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -36,6 +36,10 @@ constexpr const int kSlotsToReserve = 128; namespace tflite { // A trivial implementation of GraphInfo around the Interpreter. +// NOTE: this interpreter info represents the subset of the +// graph that is executed according to execution plan. Thus, +// the indices are execution plan indices rather than raw node +// indices. class InterpreterInfo : public GraphInfo { public: explicit InterpreterInfo(Interpreter* interpreter) @@ -45,9 +49,12 @@ class InterpreterInfo : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return interpreter_->tensor(index); } - size_t num_nodes() const override { return interpreter_->nodes_size(); } + size_t num_nodes() const override { + return interpreter_->execution_plan().size(); + } const TfLiteNode& node(size_t index) const override { - return interpreter_->node_and_registration(index)->first; + int node_index = interpreter_->execution_plan()[index]; + return interpreter_->node_and_registration(node_index)->first; } const std::vector& inputs() const override { return interpreter_->inputs(); @@ -73,7 +80,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kSlotsToReserve); nodes_and_registration_.reserve(kSlotsToReserve); - next_node_to_prepare_ = 0; + next_execution_plan_index_to_prepare_ = 0; UseNNAPI(false); } @@ -160,7 +167,7 @@ TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { } // namespace TfLiteStatus Interpreter::AllocateTensors() { - next_node_to_prepare_ = 0; + next_execution_plan_index_to_prepare_ = 0; if (memory_planner_) { TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); } @@ -190,7 +197,8 @@ TfLiteStatus Interpreter::AddNodeWithParameters( &context_, CheckTensorIndices("node outputs", outputs.data(), outputs.size())); - if (node_index) *node_index = nodes_and_registration_.size(); + int new_node_index = nodes_and_registration_.size(); + if (node_index) *node_index = new_node_index; nodes_and_registration_.resize(nodes_and_registration_.size() + 1); auto& node_and_reg = nodes_and_registration_.back(); TfLiteNode& node = node_and_reg.first; @@ -213,6 +221,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters( } node.builtin_data = builtin_data_deleter.release(); node_and_reg.second = *registration; + execution_plan_.push_back(new_node_index); return kTfLiteOk; } @@ -240,16 +249,19 @@ bool HasDynamicTensor(const TfLiteContext& context, return false; } -TfLiteStatus Interpreter::PrepareOpsStartingAt(int first_node, - int* last_node_prepared) { - for (int i = first_node; i < nodes_and_registration_.size(); i++) { - TfLiteNode& node = nodes_and_registration_[i].first; - const TfLiteRegistration& registration = nodes_and_registration_[i].second; +TfLiteStatus Interpreter::PrepareOpsStartingAt( + int first_execution_plan_index, int* last_execution_plan_index_prepared) { + for (int execution_plan_index = first_execution_plan_index; + execution_plan_index < execution_plan_.size(); execution_plan_index++) { + int node_index = execution_plan_[execution_plan_index]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + const TfLiteRegistration& registration = + nodes_and_registration_[node_index].second; if (OpPrepare(registration, &node) == kTfLiteError) { return kTfLiteError; } - *last_node_prepared = i; + *last_execution_plan_index_prepared = execution_plan_index; // Discontinue if the node has dynamic outputs. Note that we don't // stop for dynamic temporary tensors since they won't affect the @@ -268,14 +280,14 @@ TfLiteStatus Interpreter::PrepareOpsAndTensors() { memory_planner_->PlanAllocations(); } - int last_node_prepared = 0; + int last_exec_plan_index_prepared = 0; - TF_LITE_ENSURE_STATUS( - PrepareOpsStartingAt(next_node_to_prepare_, &last_node_prepared)); + TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt( + next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared)); TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations( - next_node_to_prepare_, last_node_prepared)); + next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared)); - next_node_to_prepare_ = last_node_prepared + 1; + next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1; return kTfLiteOk; } @@ -291,7 +303,8 @@ TfLiteStatus Interpreter::Invoke() { TfLiteStatus status = kTfLiteOk; if (nnapi_delegate_) { - if (next_node_to_prepare_ == nodes_and_registration_.size()) { + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); return kTfLiteOk; } else { @@ -311,13 +324,17 @@ TfLiteStatus Interpreter::Invoke() { // TODO(b/71913981): we should force recalculation in the presence of dynamic // tensors, because they may have new value which in turn may affect shapes // and allocations. - for (int i = 0; i < nodes_and_registration_.size(); i++) { - if (i == next_node_to_prepare_) { + for (int execution_plan_index = 0; + execution_plan_index < execution_plan_.size(); execution_plan_index++) { + if (execution_plan_index == next_execution_plan_index_to_prepare_) { TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); - TF_LITE_ENSURE(&context_, next_node_to_prepare_ >= i); + TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >= + execution_plan_index); } - TfLiteNode& node = nodes_and_registration_[i].first; - const TfLiteRegistration& registration = nodes_and_registration_[i].second; + int node_index = execution_plan_[execution_plan_index]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + const TfLiteRegistration& registration = + nodes_and_registration_[node_index].second; if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } @@ -421,6 +438,14 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( return kTfLiteOk; } +TfLiteStatus Interpreter::SetExecutionPlan(const std::vector& new_plan) { + for (int node_index : new_plan) { + TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size()); + } + execution_plan_ = new_plan; + return kTfLiteOk; +} + TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. @@ -434,6 +459,9 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArrayFree(new_size); return kTfLiteError; } + + // Realloc space for kTfLiteDynamic tensors. + TfLiteTensorRealloc(bytesRequired, tensor); tensor->bytes = bytesRequired; } if (tensor->dims) TfLiteIntArrayFree(tensor->dims); diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 4f732769f9f921a9debd5213547d2baccfa69426..3b077c7a3540d40a0e8d53179a4ac1aba7e292a1 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -108,7 +108,7 @@ class Interpreter { // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of - // `builtin_data` and destroy it with `delete`. Ownership of 'init_data' + // `builtin_data` and destroy it with `free`. Ownership of 'init_data' // remains with the caller. TfLiteStatus AddNodeWithParameters(const std::vector& inputs, const std::vector& outputs, @@ -166,12 +166,19 @@ class Interpreter { // Return the number of ops in the model. int nodes_size() const { return nodes_and_registration_.size(); } + // WARNING: Experimental interface, subject to change + const std::vector& execution_plan() const { return execution_plan_; } + + // WARNING: Experimental interface, subject to change + // Overrides execution plan. This bounds checks indices sent in. + TfLiteStatus SetExecutionPlan(const std::vector& new_plan); + // Get a tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { if (tensor_index >= context_.tensors_size || tensor_index < 0) - return nullptr; + return nullptr; return &context_.tensors[tensor_index]; } @@ -279,7 +286,8 @@ class Interpreter { // dynamic tensors is found or all ops have been prepared. Fill // 'last_node_prepared' with the id of the op containing dynamic tensors, or // the last in the graph. - TfLiteStatus PrepareOpsStartingAt(int first_node, int* last_node_prepared); + TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index, + int* last_execution_plan_index_prepared); // Tensors needed by the interpreter. Use `AddTensors` to add more blank // tensor entries. Note, `tensors_.data()` needs to be synchronized to the @@ -299,7 +307,8 @@ class Interpreter { TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, size_t* bytes); - // Request an tensor be resized implementation. + // Request an tensor be resized implementation. If the given tensor is of + // type kTfLiteDynamic it will also be allocated new memory. TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size); // Report a detailed error string (will be printed to stderr). @@ -354,7 +363,14 @@ class Interpreter { // node id, and execute the node to generate the output tensor before continue // to allocate successors. This process repeats until all nodes are executed. // NOTE: this relies on the order of nodes that is in topological order. - int next_node_to_prepare_; + int next_execution_plan_index_to_prepare_; + + // WARNING: This is an experimental interface that is subject to change. + // This is a list of node indices (to index into nodes_and_registration). + // This represents a valid topological sort (dependency ordered) execution + // plan. In particular, it is valid for this ordering to contain only a + // subset of the node indices. + std::vector execution_plan_; // Whether to delegate to NN API std::unique_ptr nnapi_delegate_; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index edff2109430c6e1ec6c481619ed7772237a3301d..cfda19d72cd6d4de6e4f0bdc4d369439fef2293c 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { namespace { @@ -282,6 +283,51 @@ TEST(BasicInterpreter, NoOpInterpreter) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, ResizingTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + int t = interpreter.inputs()[0]; + TfLiteTensor* tensor = interpreter.tensor(t); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + tensor->data.f[5] = 0.123f; + + // Changing from kTfLiteArenaRw to kTfLiteDynamic is quite complicate: we need + // to unset data.raw, otherwise Realloc will try to free that memory. + tensor->data.raw = nullptr; + tensor->allocation_type = kTfLiteDynamic; + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 4}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // TODO(ahentz): We shouldn't have to force reallocation, but + // ResizeInputTensor doesn't realloc dynamic tensors. Also note that + // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. + TfLiteTensorRealloc(9 * sizeof(float), tensor); + tensor->data.f[7] = 0.123f; + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {2, 2, 4}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 16 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // TODO(ahentz): We shouldn't have to force reallocation, but + // ResizeInputTensor doesn't realloc dynamic tensors. Also note that + // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. + TfLiteTensorRealloc(17 * sizeof(float), tensor); + tensor->data.f[15] = 0.123f; +} + TEST(BasicInterpreter, OneOpInterpreter) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); @@ -514,13 +560,138 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { ASSERT_EQ(reporter.calls, 1); } +// Test fixture that allows playing with execution plans. It creates a two +// node graph that can be executed in either [0,1] order or [1,0] order. +// The CopyOp records when it is invoked in the class member run_order_ +// so we can test whether the execution plan was honored. +class TestExecutionPlan : public ::testing::Test { + // Encapsulates the node ids and provides them to a C primitive data type + // Allocatable with placement new, but never destructed, so make sure this + // doesn't own any heap allocated data. This is then is used as op local + // data to allow access to the test fixture data. + class CallReporting { + public: + CallReporting(int node_id, std::vector* run_order) + : node_id_(node_id), run_order_(run_order) {} + + void Record() { run_order_->push_back(node_id_); } + + private: + // The node id for this particular node + int node_id_; + // A pointer to the global run-order + std::vector* run_order_; + }; + + // Build a kernel registration for an op that copies its one input + // to an output + TfLiteRegistration CopyOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + return context->ResizeTensor(context, tensor1, newSize); + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + CallReporting* call_reporting = + reinterpret_cast(node->builtin_data); + // Copy input data to output data. + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + a1->data.f[i] = a0->data.f[i]; + } + call_reporting->Record(); + return kTfLiteOk; + }; + return reg; + } + + // Adds a copy node going from tensor `input` to output tensor `output`. + // Note, input is used as the node_id. Inject run_order as op accessible + // data. Note: this is a little strange of a way to do this, but it is + // using op functionality to avoid static global variables. + void MakeCopyNode(int input, int output) { + // Ownership of call_reporting is taken by interpreter (malloc is used due + // to nodes being a C99 interface so free() is used). + TfLiteRegistration copy_op = CopyOpRegistration(); + CallReporting* call_reporting_1 = + reinterpret_cast(malloc(sizeof(CallReporting))); + new (call_reporting_1) CallReporting(input, &run_order_); + ASSERT_EQ(interpreter_.AddNodeWithParameters( + {0}, {2}, nullptr, 0, + reinterpret_cast(call_reporting_1), ©_op), + kTfLiteOk); + ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk); + } + + void SetUp() final { + // Add two inputs and two outputs that don't depend on each other + ASSERT_EQ(interpreter_.AddTensors(4), kTfLiteOk); + interpreter_.SetInputs({0, 1}); + interpreter_.SetOutputs({2, 3}); + TfLiteQuantizationParams quantized; + for (int tensor_index = 0; tensor_index < 4; tensor_index++) { + ASSERT_EQ(interpreter_.SetTensorParametersReadWrite( + tensor_index, kTfLiteFloat32, "", {3}, quantized), + kTfLiteOk); + } + + // Define two copy functions that also use the user_data to report that + // they were called. + // i.e. tensor[2] = copy(tensor[0]); tensor[3] = copy(tensor[1]); + // thus we can reorder the two nodes arbitrary and still satisfy dependency + // order. + MakeCopyNode(0, 2); + MakeCopyNode(1, 3); + + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + } + + protected: + Interpreter interpreter_; + + // list of node_ids that were run + std::vector run_order_; +}; + +TEST_F(TestExecutionPlan, DefaultExecutionPlan) { + // Check default order + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({0, 1})); +} + +TEST_F(TestExecutionPlan, ReversedExecutionPlan) { + // Check reversed order + interpreter_.SetExecutionPlan({1, 0}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({1, 0})); +} + +TEST_F(TestExecutionPlan, SubsetExecutionPlan) { + // Check running only node index 1 + interpreter_.SetExecutionPlan({1}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({1})); +} + +TEST_F(TestExecutionPlan, NullExecutionPlan) { + // Check nothing executed. + interpreter_.SetExecutionPlan({}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector()); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { -#ifdef OS_LINUX - FLAGS_logtostderr = true; -#endif + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 9a1a888b93ff981b1d14faa7e847e80be1f167f2..35aacb70002d1d454f675484e4398bcdffc4acf1 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -111,6 +111,26 @@ java_test( ], ) +# TODO: generate large models at runtime, instead of storing them. +java_test( + name = "InterpreterTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/InterpreterTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/mobilenet.tflite.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.InterpreterTest", + visibility = ["//visibility:private"], + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + java_test( name = "TensorTest", size = "small", diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index f3f51b668f068ffcd02862a79b72dbae31d31c02..c346f9f92e360c0722ebac440d790da6441ceecf 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -200,6 +200,12 @@ TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, return kTfLiteOk; } +// TODO(yichengfan): evaluate the benefit to use tflite verifier. +bool VerifyModel(const void* buf, size_t len) { + flatbuffers::Verifier verifier(static_cast(buf), len); + return tflite::VerifyModelBuffer(verifier); +} + } // namespace JNIEXPORT jobjectArray JNICALL @@ -271,6 +277,17 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( convertLongToErrorReporter(env, error_handle); if (error_reporter == nullptr) return 0; const char* path = env->GetStringUTFChars(model_file, nullptr); + + { + tflite::FileCopyAllocation allocation(path, nullptr); + if (!VerifyModel(allocation.base(), allocation.bytes())) { + throwException(env, kIllegalArgumentException, + "Contents of %s is not a valid flatbuffer model", path); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + } + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); if (!model) { throwException(env, kIllegalArgumentException, @@ -293,6 +310,12 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( const char* buf = static_cast(env->GetDirectBufferAddress(model_buffer)); jlong capacity = env->GetDirectBufferCapacity(model_buffer); + if (!VerifyModel(buf, capacity)) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer is not a valid flatbuffer model"); + return 0; + } + auto model = tflite::FlatBufferModel::BuildFromBuffer( buf, static_cast(capacity), error_reporter); if (!model) { diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 473f73816fd3c0a414a2c2e232dec299579fcbb6..90323555d88419d837a76bca7de6d9998e388fca 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -60,9 +60,7 @@ public final class NativeInterpreterWrapperTest { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); fail(); } catch (IllegalArgumentException e) { - assertThat(e) - .hasMessageThat() - .contains("Model provided has model identifier ' is ', should be 'TFL3'"); + assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model"); } } diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index d9051f3516367581d31f7477e0f5bb5fa0e979f8..a8ef0daede4f3b7eeffccf77263577002d512e2c 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -156,6 +156,7 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", @@ -249,6 +250,7 @@ tf_cc_test( ":builtin_ops", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 8ac93bc8c8dcfc66d3822e01b6f9b29a3e49c446..3c5c77815d0f2592ab549152b4d77f45b967a660 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include @@ -134,8 +134,7 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::max(0.f, *in); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -173,8 +172,7 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -192,8 +190,7 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 0e10a249abac3ba19cf107e055aa71d1eee00122..63ea89df56bafa995950afec3a58267681af304f 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); - for (int i = 0; i < NumDimensions(input1); ++i) { - TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), - SizeOfDimension(input2, i)); - } + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + output->type = input2->type; - TF_LITE_ENSURE_EQ(context, input1->type, output->type); - TF_LITE_ENSURE_EQ(context, input2->type, output->type); + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } - TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); return context->ResizeTensor(context, output, output_size); } template void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteAddParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_ADD(type) \ - type::Add(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops); +#define TF_LITE_ADD(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd); } else { - TF_LITE_ADD(optimized_ops); + TF_LITE_ADD(reference_ops, Add); + } + } else { + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd); + } else { + TF_LITE_ADD(optimized_ops, Add); + } } #undef TF_LITE_ADD } template void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteAddParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; auto output_offset = output->params.zero_point; @@ -112,19 +141,20 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_ADD(type) \ - type::BroadcastAdd( \ - left_shift, GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, input1_multiplier, input1_shift, \ - GetTensorData(input2), GetTensorDims(input2), input2_offset, \ - input2_multiplier, input2_shift, output_offset, output_multiplier, \ - output_shift, output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); - +#define TF_LITE_ADD(type, opname) \ + type::opname(left_shift, GetTensorData(input1), \ + GetTensorDims(input1), input1_offset, input1_multiplier, \ + input1_shift, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, input2_multiplier, \ + input2_shift, output_offset, output_multiplier, output_shift, \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops); + TF_LITE_ADD(reference_ops, BroadcastAdd); } else { - TF_LITE_ADD(optimized_ops); + TF_LITE_ADD(optimized_ops, BroadcastAdd); } #undef TF_LITE_ADD } @@ -132,15 +162,17 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { - EvalAddFloat(context, node, params, input1, input2, output); + EvalAddFloat(context, node, params, data, input1, input2, + output); } else if (output->type == kTfLiteUInt8) { - EvalAddQuantized(context, node, params, input1, input2, + EvalAddQuantized(context, node, params, data, input1, input2, output); } else { context->ReportError(context, @@ -154,19 +186,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace add TfLiteRegistration* Register_ADD_REF() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } TfLiteRegistration* Register_ADD_GENERIC_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } TfLiteRegistration* Register_ADD_NEON_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 306dfc3e803d3df34061767ba9ced032299bfa26..956d05bed5162f6ce59705d59aad77ff056dda77 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -25,10 +25,11 @@ using ::testing::ElementsAreArray; class BaseAddOpModel : public SingleOpModel { public: - BaseAddOpModel(const TensorData& input, const TensorData& output, + BaseAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type) { - input1_ = AddInput(input); - input2_ = AddInput(input); + input1_ = AddInput(input1); + input2_ = AddInput(input2); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, CreateAddOptions(builder_, activation_type).Union()); @@ -70,6 +71,7 @@ float GetTolerance(int min, int max) { TEST(FloatAddOpModel, NoActivation) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); @@ -78,9 +80,9 @@ TEST(FloatAddOpModel, NoActivation) { } TEST(FloatAddOpModel, ActivationRELU_N1_TO_1) { - FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, - ActivationFunctionType_RELU_N1_TO_1); + FloatAddOpModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); m.Invoke(); @@ -92,6 +94,7 @@ TEST(FloatAddOpModel, VariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); @@ -102,6 +105,23 @@ TEST(FloatAddOpModel, VariousInputShapes) { } } +TEST(FloatAddOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}))) + << "With shape number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { @@ -112,6 +132,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), inputs1[i]); @@ -133,6 +154,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_RELU_N1_TO_1); m.QuantizeAndPopulate(m.input1(), inputs1[i]); @@ -150,6 +172,7 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, test_shapes[i], -3.0, 3.0}, {TensorType_UINT8, {}, -3.0, 3.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); @@ -162,6 +185,25 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { } } +TEST(QuantizedAddOpModel, QuantizedWithBroadcast) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, + kQuantizedTolerance))) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index 3cee43c68b2a0af5a3fd84b33a980b74bb8f0cb4..2c5074eca3176c7f33a6f051b492dc41333257ed 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -15,14 +15,15 @@ limitations under the License. #include #include #include -#include #include +#include #include #include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -76,8 +77,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); output_size_array->data[0] = batch_size; output_size_array->data[1] = num_units; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, - output_size_array)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); return kTfLiteOk; } @@ -101,50 +102,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; - const int input_weights_stride = input_weights->dims->data[1]; - const int recurrent_weights_stride = recurrent_weights->dims->data[1]; - - // For each batch - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to input, output and bias. - const float* input_ptr_batch = input->data.f + b * input_size; - float* output_ptr_batch = output->data.f + b * num_units; - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - - // Initialize input_weights and recurrent_weights. - const float* input_weights_ptr = input_weights->data.f; - const float* recurrent_weights_ptr = recurrent_weights->data.f; - - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = - (ActivationFunctor(params->activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } - } + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Initialize the pointer to input and output. + const float* input_ptr_batch = input->data.f; + float* output_ptr_batch = output->data.f; + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index 5ecccb985e91238f1183c8f94a2b5f468758ce55..fa7ef525db47c93f98951604cd04da66196422d7 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite RNN op. -#include #include +#include #include #include @@ -120,8 +120,7 @@ static float rnn_golden_output[] = { 0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, - 0.628881, 3.58099, 1.49974, 0 -}; + 0.628881, 3.58099, 1.49974, 0}; class RNNOpModel : public SingleOpModel { public: diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 0eed680fdcc2afc4bc72be55a5e7722310fa4538..889239f93215a309d5434b209ebfc1f584c47849 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -35,12 +35,14 @@ enum KernelType { struct BatchToSpaceNDContext { BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + crops = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteBatchToSpaceNDParams* params; TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* crops; TfLiteTensor* output; }; @@ -48,23 +50,28 @@ struct BatchToSpaceNDContext { // The 4D array need to have exactly 2 spatial dimensions. // TODO(ycling): Support arbitrary dimension in BatchToSpaceND. const int kInputDimensionNum = 4; -const int kOutputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; const int kSpatialDimensionNum = 2; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now. - TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + BatchToSpaceNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int* block_shape = GetTensorData(op_context->block_shape); + const int* crops = GetTensorData(op_context->crops); - BatchToSpaceNDContext op_context(context, node); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), - kInputDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), kSpatialDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - const TfLiteIntArray* input_size = op_context.input->dims; - const int* block_shape = op_context.params->block_shape; + // TODO(ycling): Add crops as part of calculation. Remove check for a crops + // containing all zeroes. + TF_LITE_ENSURE_EQ(context, crops[0], 0); + TF_LITE_ENSURE_EQ(context, crops[1], 0); + TF_LITE_ENSURE_EQ(context, crops[2], 0); + TF_LITE_ENSURE_EQ(context, crops[3], 0); // Number of batch must be multiple of (block_shape[0] * block_shape[1]). TF_LITE_ENSURE_EQ(context, @@ -76,27 +83,48 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int output_width = input_size->data[2] * block_shape[1]; const int output_channel_size = input_size->data[3]; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); output_size->data[0] = output_batch_size; output_size->data[1] = output_height; output_size->data[2] = output_width; output_size->data[3] = output_channel_size; - return context->ResizeTensor(context, op_context.output, output_size); + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.crops)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { BatchToSpaceNDContext op_context(context, node); - int block_shape_dims_array[1] = {kSpatialDimensionNum}; - Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } -#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ - type::BatchToSpaceND(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), \ - op_context.params->block_shape, block_shape_dims, \ - GetTensorData(op_context.output), \ +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc index 3ec4efbebcef9d55d0042d93007018c9f6ee3b58..8485cde1b40066f2070855bca91ea78a9f80e83c 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -26,36 +26,76 @@ using ::testing::ElementsAreArray; class BatchToSpaceNDOpModel : public SingleOpModel { public: - BatchToSpaceNDOpModel(std::initializer_list input_shape, - std::initializer_list block_shape, - std::initializer_list before_crops, - std::initializer_list after_crops) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, - BuiltinOptions_BatchToSpaceNDOptions, - CreateBatchToSpaceNDOptions( - builder_, builder_.CreateVector(block_shape), - builder_.CreateVector(before_crops), - builder_.CreateVector(after_crops)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetCrops(std::initializer_list data) { + PopulateTensor(crops_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; + int block_shape_; + int crops_; int output_; }; -TEST(BatchToSpaceNDOpTest, SimpleTest) { - BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}); +// Tests case where block_shape and crops are const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops); +// m.SetInput(input_data); +// m.Invoke(); +class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list crops) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + crops_ = AddConstInput(TensorType_INT32, crops, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and crops are non-const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpDynamicModel m(input_shape); +// m.SetInput(input_data); +// m.SetBlockShape(block_shape); +// m.SetPaddings(crops); +// m.Invoke(); +class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpDynamicModel(std::initializer_list input_shape) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddInput(TensorType_INT32); + crops_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}, {2, 2}}); + } +}; + +TEST(BatchToSpaceNDOpTest, SimpleConstTest) { + BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); @@ -63,11 +103,35 @@ TEST(BatchToSpaceNDOpTest, SimpleTest) { 4, 8, 11, 15, 12, 16})); } +TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}), + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}), "Cannot allocate tensors"); } +TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) { + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}), + "1 != 0"); +} + +TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 1, 0}); + EXPECT_DEATH(m.Invoke(), "1 != 0"); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index f54081623578a7b1f37de8d9f111d7950c9e2757..aa24c1f34cd1e8c02a6a75b62fbe5f3c629498ca 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -119,47 +120,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -namespace { -// Performs one RNN computation step for the input specified by input_ptr_batch. -// The RNN cell is specified by the pointers to its weights and biases, along -// with the input size, number of units, strides, activation. -// The pointers to the hidden state and the output are updated as a result. -// TODO(mirkov): factor out this function to a shared library. -void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr, - const float* recurrent_weights_ptr, const float* bias_ptr, - int input_size, int num_units, int input_weights_stride, - int recurrent_weights_stride, TfLiteFusedActivation activation, - float* hidden_state_ptr_batch, float* output_ptr_batch) { - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } -} -} // namespace - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -189,15 +149,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int input_size = input->dims->data[2]; const int fw_num_units = fw_input_weights->dims->data[0]; - const int fw_input_weights_stride = fw_input_weights->dims->data[1]; - const int fw_recurrent_weights_stride = fw_recurrent_weights->dims->data[1]; const float* fw_bias_ptr = fw_bias->data.f; const float* fw_input_weights_ptr = fw_input_weights->data.f; const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f; const int bw_num_units = bw_input_weights->dims->data[0]; - const int bw_input_weights_stride = bw_input_weights->dims->data[1]; - const int bw_recurrent_weights_stride = bw_recurrent_weights->dims->data[1]; const float* bw_bias_ptr = bw_bias->data.f; const float* bw_input_weights_ptr = bw_input_weights->data.f; const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f; @@ -212,10 +168,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_ptr_batch = fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; - RnnStep(input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr, - fw_bias_ptr, input_size, fw_num_units, fw_input_weights_stride, - fw_recurrent_weights_stride, params->activation, - fw_hidden_state_ptr_batch, output_ptr_batch); + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr, + fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1, + params->activation, fw_hidden_state_ptr_batch, output_ptr_batch); } // Backward cell. float* bw_hidden_state_ptr_batch = @@ -226,10 +182,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_ptr_batch = bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; - RnnStep(input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr, - bw_bias_ptr, input_size, bw_num_units, bw_input_weights_stride, - bw_recurrent_weights_stride, params->activation, - bw_hidden_state_ptr_batch, output_ptr_batch); + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr, + bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1, + params->activation, bw_hidden_state_ptr_batch, output_ptr_batch); } } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 9e7a1233dac0f3cd02dc386f9d194597f38ca3b8..7ff907531805887afea407684fdbaa65e98d619a 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -49,6 +49,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // dimensions except 'axis' must be equal. TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; TfLiteType input_type = t0->type; + if (axis < 0) axis += t0->dims->size; TF_LITE_ENSURE(context, axis >= 0); TF_LITE_ENSURE(context, axis < t0->dims->size); @@ -131,8 +132,9 @@ template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - + int axis = params->axis; TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + if (axis < 0) axis += output->dims->size; // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should // allocate and populate these during Prepare(). @@ -141,7 +143,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { #define TF_LITE_CONCATENATION(type, scalar) \ VectorOfInputs all_inputs(*context, *node->inputs); \ type::Concatenation( \ - RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + RemapDim(NumDimensions(output), axis), all_inputs.data(), \ all_inputs.dims(), node->inputs->size, GetTensorData(output), \ GetTensorDims(output)) diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc index 499856a93cbbfbf9aa1a326912e52ce32bbbdf83..ba1ffc5f8423b9626c9c8e2a1086ea0dcca43f50 100644 --- a/tensorflow/contrib/lite/kernels/concatenation_test.cc +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -94,7 +94,7 @@ TEST(ConcatenationOpTest, TwoDimensionalOneInput) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(ConcatenationOpTest, TwoInputsTwoAxis) { +TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxes) { // We will concatenate two tensors along different dimensions. auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; @@ -107,6 +107,14 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxis) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + ConcatenationOpModel m0_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-2, + /*num_inputs=*/2); + m0_negative.SetInput(0, tensor0); + m0_negative.SetInput(1, tensor1); + m0_negative.Invoke(); + EXPECT_THAT(m0_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1, /*num_inputs=*/2); m1.SetInput(0, tensor0); @@ -114,6 +122,14 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxis) { m1.Invoke(); EXPECT_THAT(m1.GetOutput(), ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); + + ConcatenationOpModel m1_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-1, + /*num_inputs=*/2); + m1_negative.SetInput(0, tensor0); + m1_negative.SetInput(1, tensor1); + m1_negative.Invoke(); + EXPECT_THAT(m1_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } TEST(ConcatenationOpTest, FourInputs) { diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 37f499a4d09a38765aa4b8db8aa91b708edd7823..1fba3cbbce17fa35ac647912c21dae89384c8ca3 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" @@ -38,11 +39,16 @@ namespace ops { namespace builtin { namespace conv { -// This file has three implementation of Conv. +// This file has 4 implementation of Conv. enum KernelType { kReference, kGenericOptimized, // Neon-free - kNeonOptimized, + kMultithreadOptimized, + // The kernel uses use CBLAS interface for matrix multiplication. + // It's fast when an optimized CBLAS implementation is available (e.g. Apple + // Accelerate Framework), and it's slow when falling back to naive + // implementation. + kCblasOptimized, }; struct OpData { @@ -265,10 +271,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { free(hwcn_weights->data.raw); hwcn_weights->data.raw = nullptr; } + + // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and + // ResizeTensor will actually allocate space for it. The would be more + // efficient if we placed hwcn_weights_status in the persistent arena. auto hwcn_weights_status = context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; - hwcn_weights->data.raw = static_cast(malloc(hwcn_weights->bytes)); // TODO(petewarden): If Resize() is called when the size hasn't actually // changed, this will do extra redundant work. @@ -290,26 +299,34 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, auto filter_offset = -filter->params.zero_point; auto output_offset = output->params.zero_point; - if (kernel_type == kReference) { - reference_ops::Conv( - GetTensorData(input), GetTensorDims(input), input_offset, - GetTensorData(filter), GetTensorDims(filter), filter_offset, - GetTensorData(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, - output_offset, data->output_multiplier, data->output_shift, - data->output_activation_min, data->output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col), gemm_context); - } else { - optimized_ops::Conv( - GetTensorData(input), GetTensorDims(input), input_offset, - GetTensorData(filter), GetTensorDims(filter), filter_offset, - GetTensorData(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, - output_offset, data->output_multiplier, data->output_shift, - data->output_activation_min, data->output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + switch (kernel_type) { + case kReference: + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col), gemm_context); + break; + case kGenericOptimized: + case kMultithreadOptimized: + case kCblasOptimized: + // There is only one optimized implementation for Quantized Conv. + optimized_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col), gemm_context); + break; } } @@ -322,31 +339,57 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); - if (kernel_type == kReference) { - reference_ops::Conv(GetTensorData(input), GetTensorDims(input), - GetTensorData(filter), GetTensorDims(filter), - GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, - data->padding.width, data->padding.height, - output_activation_min, output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col)); - } else { - const float* filter_data; - if (data->need_hwcn_weights) { - filter_data = GetTensorData(hwcn_weights); - } else { - filter_data = GetTensorData(filter); + switch (kernel_type) { + case kReference: { + reference_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kGenericOptimized: { + optimized_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kMultithreadOptimized: { + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + multithreaded_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), + GetTensorDims(bias), params->stride_width, params->stride_height, + data->padding.width, data->padding.height, params->padding, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kCblasOptimized: { + cblas_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; } - - multithreaded_ops::Conv( - GetTensorData(input), GetTensorDims(input), filter_data, - GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, data->padding.width, - data->padding.height, params->padding, output_activation_min, - output_activation_max, GetTensorData(output), - GetTensorDims(output), GetTensorData(im2col), - GetTensorDims(im2col)); } } @@ -407,17 +450,25 @@ TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() { return &r; } -TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() { +TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() { static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, - conv::Eval}; + conv::Eval}; return &r; } TfLiteRegistration* Register_CONV_2D() { -#ifdef USE_NEON - return Register_CONVOLUTION_NEON_OPT(); +// TODO(ycling): Define a compilation flag and use CBLAS kernel when a +// fast CBLAS implementatino is available. +#ifdef TFLITE_USE_CBLAS_CONVOLUTION_KERNEL + return Register_CONVOLUTION_CBLAS_OPT(); #else - return Register_CONVOLUTION_GENERIC_OPT(); + return Register_CONVOLUTION_MULTITHREADED_OPT(); #endif } diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 1d0a81c3135625c07a3566f5f9a8e5401f0d4db7..d2393c3c97bb9516e2b8a6c8ae037dc0dfdfe64b 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -15,12 +15,25 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" #include "tensorflow/contrib/lite/model.h" namespace tflite { + +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_CONVOLUTION_REF(); +TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT(); +TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT(); +TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT(); + +} // namespace builtin +} // namespace ops + namespace { using ::testing::ElementsAreArray; @@ -30,9 +43,9 @@ class BaseConvolutionOpModel : public SingleOpModel { // TODO(ahentz): Also test different activation types, bias, padding types, // stride values. BaseConvolutionOpModel( - const TensorData& input, const TensorData& filter, - const TensorData& output, int stride_width = 2, int stride_height = 2, - enum Padding padding = Padding_VALID, + TfLiteRegistration* registration, const TensorData& input, + const TensorData& filter, const TensorData& output, int stride_width = 2, + int stride_height = 2, enum Padding padding = Padding_VALID, enum ActivationFunctionType activation = ActivationFunctionType_NONE) { input_ = AddInput(input); filter_ = AddInput(filter); @@ -62,6 +75,8 @@ class BaseConvolutionOpModel : public SingleOpModel { stride_height, activation) .Union()); + resolver_ = absl::make_unique(BuiltinOperator_CONV_2D, + registration); BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); } @@ -83,12 +98,26 @@ class ConvolutionOpModel : public BaseConvolutionOpModel { void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } }; -TEST(ConvolutionOpTest, SimpleTestFloat32) { - ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, +const auto kKernelMap = new std::map({ + {"Reference", ops::builtin::Register_CONVOLUTION_REF()}, + {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()}, + {"MultithreadedOptimized", + ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()}, + {"CblasOptimized", ops::builtin::Register_CONVOLUTION_CBLAS_OPT()}, +}); + +class ConvolutionOpTest : public SingleOpTest { + protected: + const std::map& GetKernelMap() override { + return *kKernelMap; + } +}; + +TEST_P(ConvolutionOpTest, SimpleTestFloat32) { + ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); @@ -117,8 +146,8 @@ TEST(ConvolutionOpTest, SimpleTestFloat32) { })); } -TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { - ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}}, +TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { + ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, /*stride_width=*/3, /*stride_height=*/1); @@ -139,7 +168,7 @@ TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { })); } -TEST(ConvolutionOpTest, HandCalculatedFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -150,6 +179,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -192,7 +222,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) { 178, 187, 234, 261, 121})); } -TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -203,6 +233,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -245,7 +276,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { 367, 188, 197, 244, 271, 131})); } -TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedWithReluFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -256,6 +287,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -300,7 +332,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0})); } -TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -311,6 +343,7 @@ TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { const int stride_height = 1; const Padding padding = Padding_VALID; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -366,8 +399,9 @@ class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { // In this tests we set the input and output scales so that the results // match exactly the 'non-quantized' version. -TEST(ConvolutionOpTest, SimpleTestQuantized) { - QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, +TEST_P(ConvolutionOpTest, SimpleTestQuantized) { + QuantizedConvolutionOpModel m(GetRegistration(), + {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, {TensorType_UINT8, {}, -127, 128}); m.SetInput({ @@ -405,8 +439,9 @@ TEST(ConvolutionOpTest, SimpleTestQuantized) { })); } -TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { - QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, +TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { + QuantizedConvolutionOpModel m(GetRegistration(), + {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, {TensorType_UINT8, {}, -127, 128}, /*stride_width=*/3, /*stride_height=*/1); @@ -430,6 +465,11 @@ TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { 167, 93, // })); } + +INSTANTIATE_TEST_CASE_P( + ConvolutionOpTest, ConvolutionOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc index dcdc5fffad9ceac1a9d23a4e91637a9ff92a8dda..ef2b5422253ea880a9ded4d3c0efc5cec07178a9 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -123,18 +123,16 @@ TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) { [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); m.Invoke(); - EXPECT_THAT( - m.GetOutput(), - ElementsAreArray(ArrayFloatNear({ - 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 - 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - - 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), - 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), - 7.20f / std::sqrt(20.0f), - 7.26f / - std::sqrt( - 20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * Row 3 + 4 * Row 0 - }))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), + 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), + 7.20f / std::sqrt(20.0f), + 7.26f / std::sqrt(20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * + // Row 3 + 4 * Row 0 + }))); } TEST(EmbeddingLookupOpTest, Indices3DTest) { diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index 658d977b8dc7fffcdde69d74ba2564dfa1b5709e..cdadbeda1884ba0186846826dd16be6ff69878d9 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -81,10 +81,8 @@ TEST(GatherOpTest, Test0DIndex) { m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); m.SetPositions({1}); m.Invoke(); - EXPECT_THAT(m.GetOutputFloat(), - ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); - EXPECT_THAT(m.GetOutputShape(), - ElementsAreArray({2})); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); } TEST(GatherOpTest, Test0DIndexWith0DResult) { @@ -94,8 +92,7 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) { m.SetInputFloat({1.0, 2.0, 3.0}); m.SetPositions({1}); m.Invoke(); - EXPECT_THAT(m.GetOutputFloat(), - ElementsAreArray(ArrayFloatNear({2.0}))); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0}))); EXPECT_TRUE(m.GetOutputShape().empty()); } diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc index cb6038f9009a3865661e7b4f075c3033166d0f91..ba0ed5ce06392613238b757308dddc2b22e7eb30 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -116,7 +116,10 @@ TEST(HashtableLookupOpTest, Test2DInput) { 1.0, 1.1, // 1-st item }))); EXPECT_THAT(m.GetHit(), ElementsAreArray({ - 1, 0, 1, 1, + 1, + 0, + 1, + 1, })); } diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 38b032c6de7987ff5e3da3ba5fcf4e9fc8574c44..adedd58ff438ec9136f14e81e5cf93f6031890ef 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -124,6 +124,13 @@ config_setting( }, ) +config_setting( + name = "darwin_x86_64", + values = { + "cpu": "darwin_x86_64", + }, +) + config_setting( name = "freebsd", values = { @@ -154,6 +161,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, ":freebsd": tflite_deps_intel, "//conditions:default": [], }), @@ -162,6 +170,8 @@ cc_library( cc_library( name = "optimized", hdrs = [ + "optimized/cblas_conv.h", + "optimized/cblas_reference.h", "optimized/eigen_spatial_convolutions.h", "optimized/eigen_tensor_reduced_instantiations_oss.h", "optimized/multithreaded_conv.h", @@ -232,6 +242,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, ":freebsd": tflite_deps_intel, "//conditions:default": [], }), @@ -284,6 +295,16 @@ cc_library( ], ) +cc_library( + name = "kernel_utils", + srcs = ["kernel_utils.cc"], + hdrs = ["kernel_utils.h"], + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + ], +) + cc_library( name = "tensor_utils", srcs = [ @@ -330,6 +351,9 @@ cc_library( ":x86": [ ":neon_tensor_utils", ], + ":k8": [ + ":neon_tensor_utils", + ], ":darwin": [ ":neon_tensor_utils", ], diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h index 1d963afb7e1ce414f251f090208923ca0c68cee1..51426bb1c584b82af7b1a2ffaf5a675a1dd9a6fd 100644 --- a/tensorflow/contrib/lite/kernels/internal/compatibility.h +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -27,6 +27,10 @@ limitations under the License. #define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false) #endif +#ifndef TFLITE_DCHECK_NE +#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : assert(false) +#endif + #ifndef TFLITE_DCHECK_GE #define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false) #endif @@ -52,6 +56,10 @@ limitations under the License. #define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort() #endif +#ifndef TFLITE_CHECK_NE +#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : abort() +#endif + #ifndef TFLITE_CHECK_GE #define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort() #endif diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..510395126ce3785b1d44fec1e0eb994c29ff0db7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +namespace tflite { +namespace kernel_utils { + +void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch) { + // Output = bias + tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, + output_ptr_batch); + // Output += input * input_weights + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size, + output_ptr_batch, /*result_stride=*/1); + // Output += recurrent_weights * hidden_state + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch, + batch_size, output_ptr_batch, /*result_stride=*/1); + // Output = activation(Output) and update hidden_state + tensor_utils::ApplyActivationToVector( + output_ptr_batch, num_units * batch_size, activation, output_ptr_batch); + tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size, + hidden_state_ptr_batch); +} + +} // namespace kernel_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9872d4500b862388ed4b96c97e3755f548e35d35 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace kernel_utils { + +// Performs an RNN batch inference step for inputs specified by input_ptr_batch. +// The RNN cell is specified by the pointers to its input and recurrent weights, +// and biases, along with the input size, number of units, activation. +// +// The pointers to the hidden state and the output are updated as a result. +// +// The pointers with the suffix "_batch" point to data aligned in batch_major +// order, and each step processes batch_size many inputs from input_ptr_batch, +// and updates batch_size many outputs and hidden states. +void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch); + +} // namespace kernel_utils +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..fcb9fac6713865f2d6f89755c785b179436cbc57 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h @@ -0,0 +1,89 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ + +// The Conv implementation based on CBLAS interface. This is only used on iOS +// for now, utilizing Apple's Accelerate framework. + +// TODO(ycling): Update the BUILD file and integrate with Apple Accelerate +// Famework when it's available. +#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" + +namespace tflite { +namespace cblas_ops { + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("Conv/cblas"); + + const float* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + optimized_ops::Im2col(input_data, input_dims, stride_width, stride_height, + pad_width, pad_height, filter_height, filter_width, 0, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + // The following code computes matrix multiplication c = a * transponse(b) + // with CBLAS, where: + // * `a` is a matrix with dimensions (m, k). + // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n). + // * `c` is a matrix with dimensions (m, n). + // The naming of variables are aligned with CBLAS specification here. + const float* a = gemm_input_data; + const float* b = filter_data; + float* c = output_data; + int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; + int n = output_dims.sizes[0]; + int k = gemm_input_dims->sizes[0]; + // The stride of matrix a, b and c respectively. + int stride_a = k; + int stride_b = k; + int stride_c = n; + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a, + stride_a, b, stride_b, 0.0f, c, stride_c); + + optimized_ops::AddBiasAndEvalActivationFunction( + bias_data, bias_dims, output_data, output_dims, output_activation_min, + output_activation_max); +} + +} // namespace cblas_ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h new file mode 100644 index 0000000000000000000000000000000000000000..6acc513805c9398c304f3e24175d3bd6c96938f6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +// The reference implementation for a small subset of CBLAS interface. +// This is only used for testing CBLAS implementation, and should never be used +// in production code. + +namespace tflite { +namespace cblas_ops { + +// The following code follows the original CBLAS specification, and it might +// conflict with the TensorFlow naming convention. +// TODO(ycling): Find another way to test CBLAS with bazel, without writing +// a reference implementation by ourselves. +enum CBLAS_ORDER { CblasRowMajor = 0, CblasColMajor = 1 }; + +enum CBLAS_TRANSPOSE { CblasNoTrans = 0, CblasTrans = 1, CblasConjTrans = 2 }; + +// A reference implementation for matrix multiplication. +// The following code computes, c = a * transponse(b) matrix multiplication +// with CBLAS, where: +// * `a` is a matrix with dimensions (m, k). +// * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n). +// * `c` is a matrix with dimensions (m, n). +// The naming of variables is aligned with CBLAS specification here. +void cblas_sgemm(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE trans_a, + const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, + const int k, const float alpha, const float *a, + const int stride_a, const float *b, const int stride_b, + const float beta, float *c, const int stride_c) { + TFLITE_DCHECK(order == CblasRowMajor); + TFLITE_DCHECK(trans_a == CblasNoTrans); + TFLITE_DCHECK(trans_b == CblasTrans); + TFLITE_DCHECK(beta == 0.0f); + for (int row = 0; row < m; ++row) { + for (int col = 0; col < n; ++col) { + // If `beta` non-zero, multiple it with the original values in output. + // Otherwise, ignore the original value in output completely. + float value = 0.0f; + for (int idx = 0; idx < k; ++idx) { + value += alpha * a[stride_a * row + idx] * b[stride_b * col + idx]; + } + c[stride_c * row + col] = value; + } + } +} + +} // namespace cblas_ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h index 629783d7e58cf740a8633c708ca9821667f86123..e0eca2e736be00ff09737325f06b0035e77e3103 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -36,15 +36,11 @@ inline bool TestCPUFeatureNeon() { #elif defined USE_NEON || defined __ARM_NEON -inline bool TestCPUFeatureNeon() { - return true; -} +inline bool TestCPUFeatureNeon() { return true; } #else -inline bool TestCPUFeatureNeon() { - return false; -} +inline bool TestCPUFeatureNeon() { return false; } #endif diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index 81796e295d9c7ae1f04163467c8b2af851b632c2..e2c87df80bd927d823b150ed3799641796dfb4c7 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -992,11 +992,11 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, for (int k = 0; k < 4; k++) { acc[k] = vld1q_f32(acc_buffer + i + 4 * k); } - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32( - vdupq_n_f32(output_activation_min), - vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); - } + for (int k = 0; k < 4; k++) { + acc[k] = vmaxq_f32( + vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); + } for (int k = 0; k < 4; k++) { vst1q_f32(output_ptr + 4 * k, acc[k]); } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h index f21fbf532ac01ced594715d0a0da9bd6e6f8d0e2..ce3cde76999c77e1f9bf1eaccdba7e84ed508dda 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -39,7 +39,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #endif - namespace Eigen { /** SpatialConvolution @@ -215,13 +214,12 @@ EIGEN_DEVICE_FUNC } // TODO(yangke): choose() is defined in TensorContraction.h -- consider // moving it to somewhere more "common". - return - input - .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, - row_in_stride, col_in_stride, padding_type) - .reshape(pre_contract_dims) - .contract(kernel.reshape(kernel_dims), contract_dims) - .reshape(post_contract_dims); + return input + .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); } } // end namespace Eigen diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 8163c76cfd2eb9b320fe65e54c6b88f3d694a598..d5b0f45fd81d60aed2d0faafe01c9f7d236c4342 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -2938,6 +2938,55 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, output_map.array() = input_map.array().tanh(); } +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); + + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); + // TODO(mjmatthews): properly wire through this zero offset + output_val_s32 += 127; + if (output_val_s32 == -1) { + // May underflow since we cannot properly represent -1.0f + output_val_s32 = 0; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -3410,7 +3459,7 @@ inline void ResizeBilinearGeneric(const float* input_data, inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, bool align_corners) { gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ -3425,13 +3474,20 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; // Specialize for 2x2 upsample. - if (output_height == 2 * input_height && output_width == 2 * input_width) { + if (!align_corners && output_height == 2 * input_height && + output_width == 2 * input_width) { ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, input_height, input_width, depth, output_height, output_width); } else { float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; + if (align_corners && output_height > 1) { + height_scale = static_cast(input_height - 1) / (output_height - 1); + } + if (align_corners && output_width > 1) { + width_scale = static_cast(input_width - 1) / (output_width - 1); + } ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, batches, input_height, input_width, depth, @@ -3440,6 +3496,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index f8be99e82fb8721ced7a3e5da686b20ce241ea2d..4e324a5e107cf5a90c0042331899edab831c8e51 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ #define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ -// TDOD(ghodrat): Remove this header file and the dependency to internal data +// TODO(ghodrat): Remove this header file and the dependency to internal data // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index afc3e26e7988a369fb777ae99c08c4e98f26ebb8..c05c21b472b05f2cbe133adf94d91ab0c6d9ef40 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ -// TDOD(ghodrat): Remove this header file and the dependency to internal data +// TODO(ghodrat): Remove this header file and the dependency to internal data // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 31bade26f98274e64fc7e224a16d5b78bc8bbe68..40e5c48a4ca6047a5c1cad0c8805ac8f281b357c 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -2043,6 +2043,55 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } } +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); + + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); + // TODO(mjmatthews): properly wire through this zero offset + output_val_s32 += 127; + if (output_val_s32 == -1) { + // May underflow since we cannot properly represent -1.0f + output_val_s32 = 0; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -2202,7 +2251,7 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); int32 input_width = ArraySize(input_dims, 1); @@ -2216,6 +2265,12 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; + if (align_corners && output_height > 1) { + height_scale = static_cast(input_height - 1) / (output_height - 1); + } + if (align_corners && output_width > 1) { + width_scale = static_cast(input_width - 1) / (output_width - 1); + } for (int b = 0; b < batches; ++b) { for (int y = 0; y < output_height; ++y) { @@ -2243,6 +2298,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -2370,13 +2434,15 @@ inline int StartIndex(int start, int stride, int dim, bool masked) { return masked ? (stride > 0 ? 0 : dim - 1) : start; } -inline int StopIndex(int stop, int stride, int dim, bool masked) { - return masked ? (stride > 0 ? dim : -1) : stop; +inline int StopIndex(int start, int stop, int stride, int dim, bool masked, + bool shrink_axis_masked) { + return shrink_axis_masked ? stride > 0 ? start + 1 : start - 1 + : masked ? (stride > 0 ? dim : -1) : stop; } template inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, + int begin_mask, int end_mask, int shrink_axis_mask, const std::vector& starts, const std::vector& stops, const std::vector& strides, T* output_data, @@ -2387,19 +2453,23 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, const int start_b = StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8); const int stop_b = - StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8); + StopIndex(start_b, stops[3], strides[3], input_dims.sizes[3], + end_mask & 8, shrink_axis_mask & 8); const int start_h = StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4); const int stop_h = - StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4); + StopIndex(start_h, stops[2], strides[2], input_dims.sizes[2], + end_mask & 4, shrink_axis_mask & 4); const int start_w = StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2); const int stop_w = - StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2); + StopIndex(start_w, stops[1], strides[1], input_dims.sizes[1], + end_mask & 2, shrink_axis_mask & 2); const int start_d = StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1); const int stop_d = - StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1); + StopIndex(start_d, stops[0], strides[0], input_dims.sizes[0], + end_mask & 1, shrink_axis_mask & 1); T* out_ptr = output_data; for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]); @@ -2417,6 +2487,18 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + StridedSlice(input_data, input_dims, begin_mask, end_mask, + /*shrink_axis_mask=*/0, starts, stops, strides, output_data, + output_dims); +} + template inline void Slice(const T* input_data, const Dims<4>& input_dims, const std::vector& begin, const std::vector& size, diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 3cfa72615a95d6f215ef9d35f2572026ec90fad8..28f53b9fbbc5620f2fab5c73e40bed8af4af5f1e 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -65,7 +65,10 @@ inline bool IsDynamicTensor(TfLiteTensor* tensor) { // Sets tensor to dynamic. inline void SetTensorToDynamic(TfLiteTensor* tensor) { - tensor->allocation_type = kTfLiteDynamic; + if (tensor->allocation_type != kTfLiteDynamic) { + tensor->allocation_type = kTfLiteDynamic; + tensor->data.raw = nullptr; + } } // Calculates the multiplication factor for a quantized convolution (or diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc index 540e5a364dd60a42c316199d0ebe878ae07e6756..ec1c40202761e3789462a4740e5547eba654b0f9 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/mean.cc @@ -35,10 +35,12 @@ struct MeanContext { MeanContext(TfLiteContext* context, TfLiteNode* node) { params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + axis = GetInput(context, node, 1); output = GetOutput(context, node, 0); } TfLiteMeanParams* params; TfLiteTensor* input; + TfLiteTensor* axis; TfLiteTensor* output; }; @@ -54,45 +56,26 @@ void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast(buffer); } -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - - MeanContext op_context(context, node); - int input_num_dims = NumDimensions(op_context.input); - int axis_num_dims = op_context.params->num_axis_dimensions; - - // Creates a temp index to iterate through input data. - int* scratch_tensor_index = reinterpret_cast(node->user_data); - TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); - node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; - scratch_tensor->type = kTfLiteInt32; - scratch_tensor->allocation_type = kTfLiteArenaRw; - TfLiteIntArray* index_size = TfLiteIntArrayCreate(1); - index_size->data[0] = input_num_dims; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, scratch_tensor, index_size)); - - // Creates a temp tensor to store resolved axis given input data. - node->temporaries->data[1] = *scratch_tensor_index + 1; - TfLiteTensor* axis_tensor = &context->tensors[node->temporaries->data[1]]; - axis_tensor->type = kTfLiteInt32; - axis_tensor->allocation_type = kTfLiteArenaRw; +// Resizes the temp tensor that stores resolved axis. +TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, + TfLiteTensor* resolved_axis) { TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); - axis_size->data[0] = op_context.params->num_axis_dimensions; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, axis_tensor, axis_size)); + axis_size->data[0] = static_cast(NumElements(op_context->axis)); + return context->ResizeTensor(context, resolved_axis, axis_size); +} - // Determines size of output tensor. - const TfLiteIntArray* input_dims = op_context.input->dims; - const int* axis = op_context.params->axis; - if (op_context.params->keep_dims) { +// Resizes output array based on the input size and resolved axis. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + MeanContext* op_context) { + size_t num_axis = NumElements(op_context->axis); + const TfLiteIntArray* input_dims = op_context->input->dims; + int input_num_dims = NumDimensions(op_context->input); + const int* axis = GetTensorData(op_context->axis); + if (op_context->params->keep_dims) { TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims); for (int idx = 0; idx < input_num_dims; ++idx) { bool is_axis = false; - for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) { + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { is_axis = true; break; @@ -104,11 +87,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_dims->data[idx] = input_dims->data[idx]; } } - return context->ResizeTensor(context, op_context.output, output_dims); + return context->ResizeTensor(context, op_context->output, output_dims); } else { // Calculates size of reducing axis. - int num_reduce_axis = axis_num_dims; - for (int i = 0; i < axis_num_dims; ++i) { + int num_reduce_axis = num_axis; + for (int i = 0; i < num_axis; ++i) { int current = axis[i]; if (current < 0) { current += input_num_dims; @@ -131,7 +114,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int num_skip_axis = 0; for (int idx = 0; idx < input_num_dims; ++idx) { bool is_axis = false; - for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) { + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { ++num_skip_axis; is_axis = true; @@ -142,24 +125,76 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; } } - return context->ResizeTensor(context, op_context.output, output_dims); + return context->ResizeTensor(context, op_context->output, output_dims); + } +} + +// Initializes temp tensors to store index and resolved axis. +TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, + MeanContext* op_context) { + // Creates a temp index to iterate through input data. + int* scratch_tensor_index = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = kTfLiteInt32; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* index_size = TfLiteIntArrayCreate(1); + index_size->data[0] = NumDimensions(op_context->input); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, scratch_tensor, index_size)); + + // Creates a temp tensor to store resolved axis given input data. + node->temporaries->data[1] = *scratch_tensor_index + 1; + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + resolved_axis->type = kTfLiteInt32; + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MeanContext op_context(context, node); + TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); + + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + // Leaves work to Eval if axis is not constant; else resizes output. + if (!IsConstantTensor(op_context.axis)) { + SetTensorToDynamic(op_context.output); + SetTensorToDynamic(resolved_axis); + return kTfLiteOk; } + resolved_axis->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { MeanContext op_context(context, node); + int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]]; TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(resolved_axis->bytes, resolved_axis); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } -#define TF_LITE_MEAN(kernel_type, data_type) \ - kernel_type::Mean<>( \ - GetTensorData(op_context.input), \ - op_context.input->dims->data, op_context.input->dims->size, \ - GetTensorData(op_context.output), \ - op_context.output->dims->data, op_context.output->dims->size, \ - op_context.params->axis, op_context.params->num_axis_dimensions, \ - op_context.params->keep_dims, GetTensorData(temp_index), \ +#define TF_LITE_MEAN(kernel_type, data_type) \ + kernel_type::Mean<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ GetTensorData(resolved_axis)) if (kernel_type == kReference) { diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc index 4305c0632f5a52b858a056109187ad4a0cc2e46e..c4c53c2ded351849e7c458fc754c36395a25ebd0 100644 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ b/tensorflow/contrib/lite/kernels/mean_test.cc @@ -25,58 +25,108 @@ using ::testing::ElementsAreArray; class BaseMeanOpModel : public SingleOpModel { public: - BaseMeanOpModel(const TensorData& input, const TensorData& output, - std::initializer_list axis, bool keep_dims) { - input_ = AddInput(input); - output_ = AddOutput(output); - SetBuiltinOp( - BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, builder_.CreateVector(axis), keep_dims) - .Union()); - BuildInterpreter({GetShape(input_)}); + void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } - int input() { return input_; } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } protected: int input_; + int axis_; int output_; }; -class FloatMeanOpModel : public BaseMeanOpModel { +// Model for the tests case where axis is a const tensor. +class MeanOpConstModel : public BaseMeanOpModel { public: - using BaseMeanOpModel::BaseMeanOpModel; - - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + MeanOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, + CreateMeanOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); } +}; - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } +// Model for the tests case where axis is a dynamic tensor. +class MeanOpDynamicModel : public BaseMeanOpModel { + public: + MeanOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, + CreateMeanOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } }; -TEST(FloatMeanOpTest, NotKeepDims) { +TEST(ConstMeanOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(ConstMeanOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +TEST(DynamicMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, - {1, 0, -3, -3}, false); + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); m.SetInput(data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); } -TEST(FloatMeanOpTest, KeepDims) { +TEST(DynamicMeanOpTest, KeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, - {0, 2}, true); + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, + true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); m.SetInput(data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 81c73f2523186c2d4072d56bdc8980fcdbb588a3..54575019de4c678ce25561cf2ac8dc80c9973363 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); - for (int i = 0; i < NumDimensions(input1); ++i) { - TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), - SizeOfDimension(input2, i)); - } + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + output->type = input2->type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); - TF_LITE_ENSURE_EQ(context, input1->type, output->type); - TF_LITE_ENSURE_EQ(context, input2->type, output->type); + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } - TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); return context->ResizeTensor(context, output, output_size); } template void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteMulParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_MUL(type) \ - type::Mul(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops); + if (data->requires_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul); + } else { + TF_LITE_MUL(reference_ops, Mul); + } } else { - TF_LITE_MUL(optimized_ops); + if (data->requires_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } } #undef TF_LITE_MUL } template void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteMulParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; auto output_offset = output->params.zero_point; @@ -98,17 +127,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_MUL(type) \ - type::BroadcastMul(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ - GetTensorDims(output)); +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, output_offset, \ + output_multiplier, output_shift, output_activation_min, \ + output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + // The quantized version of Mul doesn't support activations, so we + // always use BroadcastMul. if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops); + TF_LITE_MUL(reference_ops, BroadcastMul); } else { - TF_LITE_MUL(optimized_ops); + TF_LITE_MUL(optimized_ops, BroadcastMul); } #undef TF_LITE_MUL } @@ -116,15 +147,17 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, input1, input2, output); + EvalFloat(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, input1, input2, output); + EvalQuantized(context, node, params, data, input1, input2, + output); } else { context->ReportError(context, "Mul only supports FLOAT32 and quantized UINT8 now."); @@ -137,19 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace mul TfLiteRegistration* Register_MUL_REF() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } TfLiteRegistration* Register_MUL_GENERIC_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } TfLiteRegistration* Register_MUL_NEON_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index 8838b300c0af167bf2ffcf944fc7c31d6173f462..f1a30f82634631ba8320421d5b36ffe446f443fa 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -25,10 +25,11 @@ using ::testing::ElementsAreArray; class BaseMulOpModel : public SingleOpModel { public: - BaseMulOpModel(TensorData input, TensorData output, + BaseMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type) { - input1_ = AddInput(input); - input2_ = AddInput(input); + input1_ = AddInput(input1); + input2_ = AddInput(input2); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, CreateMulOptions(builder_, activation_type).Union()); @@ -70,6 +71,7 @@ class QuantizedMulOpModel : public BaseMulOpModel { TEST(FloatMulOpTest, NoActivation) { FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); @@ -79,9 +81,9 @@ TEST(FloatMulOpTest, NoActivation) { } TEST(FloatMulOpTest, ActivationRELU_N1_TO_1) { - FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, - ActivationFunctionType_RELU_N1_TO_1); + FloatMulOpModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 5}); m.Invoke(); @@ -94,6 +96,7 @@ TEST(FloatMulOpTest, VariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); @@ -105,8 +108,26 @@ TEST(FloatMulOpTest, VariousInputShapes) { } } +TEST(FloatMulOpTest, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}))) + << "With shape number " << i; + } +} + TEST(QuantizedMulOpTest, NoActivation) { QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); @@ -117,6 +138,32 @@ TEST(QuantizedMulOpTest, NoActivation) { kQuantizedTolerance))); } +// for quantized Mul, the error shouldn't exceed 2*step +float GetTolerance(int min, int max) { + float kQuantizedStep = (max - min) / 255.0; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + +TEST(QuantizedMulOpTest, WithBroadcast) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedMulOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, // always a scalar + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}, kQuantizedTolerance))) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index 17166715ca30ff3d8ba3d384110e403f8910e39d..cee3ec6197c698a11004d42dccdfe2bcca088015 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -243,7 +243,6 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; - TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { const int n_batch = 1; const int n_input = 2; @@ -282,7 +281,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 0.04717243, 0.48944736, -0.38535351, -0.17212132}); diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 4003ed10df4df1e36fd654322a213f5513cafcaa..48114e5a4069abf864a996141c7b0906301d9809 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -177,9 +177,7 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() { return &r; } -TfLiteRegistration* Register_PAD() { - return Register_PAD_GENERIC_OPT(); -} +TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index 9a419af0238e1a25e4b9e81f109b54de6b49097b..c5d60cae3ab0b203299e04a25c392519a2a23b75 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -36,6 +36,17 @@ constexpr int kInputTensor = 0; constexpr int kSizeTensor = 1; constexpr int kOutputTensor = 0; +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input, + TfLiteTensor* size, TfLiteTensor* output) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + const int32* size_data = GetTensorData(size); + output_size->data[1] = size_data[0]; + output_size->data[2] = size_data[1]; + output_size->data[3] = input->dims->data[3]; + return context->ResizeTensor(context, output, output_size); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -55,32 +66,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // integers. output->type = kTfLiteFloat32; - // TODO(ahentz): if the input is constant, we can allocate here. - output->allocation_type = kTfLiteDynamic; - return kTfLiteOk; + if (!IsConstantTensor(size)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, input, size, output); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* size = GetInput(context, node, kSizeTensor); - // TODO(ahentz): we only need to do this here if it wasn't done in Eval(). - TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); - output_size->data[0] = input->dims->data[0]; - const int32* size_data = GetTensorData(size); - output_size->data[1] = size_data[0]; - output_size->data[2] = size_data[1]; - output_size->data[3] = input->dims->data[3]; - context->ResizeTensor(context, output, output_size); - TfLiteTensorRealloc(output->bytes, output); + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputTensor(context, input, size, output)); + TfLiteTensorRealloc(output->bytes, output); + } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - GetTensorData(size), GetTensorDims(size), \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ + params->align_corners) if (kernel_type == kReference) { TF_LITE_RESIZE_BILINEAR(reference_ops); diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 2b1aaf654f87f435ec464b2cc1a63c77ba86ae5b..4e03f3820a5c14ee1692c553db61e385716b1723 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -25,14 +25,24 @@ using ::testing::ElementsAreArray; class ResizeBilinearOpModel : public SingleOpModel { public: - ResizeBilinearOpModel(std::initializer_list input_shape) { - input_ = AddInput(TensorType_FLOAT32); - size_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + ResizeBilinearOpModel(const TensorData& input, + std::initializer_list size_data = {}) { + bool const_size = size_data.size() != 0; + input_ = AddInput(input); + if (const_size) { + size_ = AddConstInput(TensorType_INT32, size_data, {2}); + } else { + size_ = AddInput({TensorType_INT32, {2}}); + } + output_ = AddOutput(TensorType_FLOAT32); // Always float. SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); - BuildInterpreter({input_shape, {2}}); + if (const_size) { + BuildInterpreter({GetShape(input_)}); + } else { + BuildInterpreter({GetShape(input_), GetShape(size_)}); + } } void SetInput(std::initializer_list data) { @@ -49,23 +59,33 @@ class ResizeBilinearOpModel : public SingleOpModel { }; TEST(ResizeBilinearOpTest, HorizontalResize) { - ResizeBilinearOpModel m({1, 1, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { - ResizeBilinearOpModel m({1, 2, 1, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12 // @@ -77,10 +97,22 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) { 7, 9, 10, // 9, 11, 12, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { - ResizeBilinearOpModel m({2, 2, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12, // @@ -97,10 +129,27 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { 8, 12, 14, // 10, 14, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 2}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // @@ -112,6 +161,18 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { 7, 8, 9, 12, 10, 14, // 9, 10, 11, 14, 12, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index 2e22d0db56a233bf554c57cf86275832ce941a18..e2e1873f770fad889137b43d87585602162819f7 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -33,17 +33,16 @@ enum KernelType { kGenericOptimized, }; -// Inputs specified in the 2nd tensor (block_shape) and 3rd tensor (paddings) -// are ignored. Only use the `block_shape` and `paddings` specified in params. -// TODO(nupurgarg): Support inputs as tensors in SpaceToBatchND. struct SpaceToBatchNDContext { SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + paddings = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteSpaceToBatchNDParams* params; TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* paddings; TfLiteTensor* output; }; @@ -51,32 +50,29 @@ struct SpaceToBatchNDContext { // The 4D array need to have exactly 2 spatial dimensions. // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND. const int kInputDimensionNum = 4; -const int kOutputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; const int kSpatialDimensionNum = 2; -const int kPaddingDimensionNum = 4; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + SpaceToBatchNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int32* block_shape = GetTensorData(op_context->block_shape); + const int32* paddings_data = GetTensorData(op_context->paddings); - SpaceToBatchNDContext op_context(context, node); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), - kInputDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), kSpatialDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - - const TfLiteIntArray* input_size = op_context.input->dims; - const int* block_shape = op_context.params->block_shape; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); // Ensures the input height and width (with padding) is a multiple of block // shape height and width. for (int dim = 0; dim < kSpatialDimensionNum; ++dim) { - int final_dim_size = - (input_size->data[dim + 1] + op_context.params->before_paddings[dim] + - op_context.params->after_paddings[dim]); + int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] + + paddings_data[dim * 2 + 1]); TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0); output_size->data[dim + 1] = final_dim_size / block_shape[dim]; } @@ -88,33 +84,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size->data[0] = output_batch_size; output_size->data[3] = output_channel_size; - return context->ResizeTensor(context, op_context.output, output_size); + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + SpaceToBatchNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.paddings)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { SpaceToBatchNDContext op_context(context, node); - int block_shape_dims_array[1] = {kSpatialDimensionNum}; - Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); - - // Initialize padding array in the format accepted by the kernel code. - // TODO(nupurgarg): Make kernel code accept padding array format that is - // consistent with Pad operation (i.e. before_paddings and after_paddings). - TfLiteIntArray* padding_data = TfLiteIntArrayCreate(kPaddingDimensionNum); - padding_data->data[0] = op_context.params->before_paddings[0]; - padding_data->data[1] = op_context.params->after_paddings[0]; - padding_data->data[2] = op_context.params->before_paddings[1]; - padding_data->data[3] = op_context.params->after_paddings[1]; - int padding_dims_array[1] = {kPaddingDimensionNum}; - Dims<4> padding_dims = GetTensorDims(padding_dims_array, 1); - -#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ - type::SpaceToBatchND(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), \ - op_context.params->block_shape, block_shape_dims, \ - padding_data->data, padding_dims, \ - GetTensorData(op_context.output), \ + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } + +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ + type::SpaceToBatchND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData(op_context.paddings), \ + GetTensorDims(op_context.paddings), \ + GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: @@ -151,8 +158,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } #undef TF_LITE_SPACE_TO_BATCH_ND - - TfLiteIntArrayFree(padding_data); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc index 45a6aef73d05b57a7f9a7fc6f58c3971c6e03118..92a4a037d5873e608ee7bdbdfc5eaa5e9b62bc8c 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc @@ -26,41 +26,81 @@ using ::testing::ElementsAreArray; class SpaceToBatchNDOpModel : public SingleOpModel { public: - SpaceToBatchNDOpModel(std::initializer_list input_shape, - std::initializer_list block_shape, - std::initializer_list before_paddings, - std::initializer_list after_paddings) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, - BuiltinOptions_SpaceToBatchNDOptions, - CreateSpaceToBatchNDOptions( - builder_, builder_.CreateVector(block_shape), - builder_.CreateVector(before_paddings), - builder_.CreateVector(after_paddings)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetPaddings(std::initializer_list data) { + PopulateTensor(paddings_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; + int block_shape_; + int paddings_; int output_; }; +// Tests case where block_shape and paddings are const tensors. +// +// Example usage is as follows: +// SpaceToBatchNDOpConstModel m(input_shape, block_shape, paddings); +// m.SetInput(input_data); +// m.Invoke(); +class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and paddings are non-const tensors. +// +// Example usage is as follows: +// SpaceToBatchNDOpDynamicModel m(input_shape); +// m.SetInput(input_data); +// m.SetBlockShape(block_shape); +// m.SetPaddings(paddings); +// m.Invoke(); +class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpDynamicModel(std::initializer_list input_shape) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddInput(TensorType_INT32); + paddings_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}, {2, 2}}); + } +}; + TEST(SpaceToBatchNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(SpaceToBatchNDOpModel({1, 3, 3, 1}, {2, 2}, {0, 0}, {0, 0}), + EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}), "Cannot allocate tensors"); } -TEST(SpaceToBatchNDOpTest, SimpleTest) { - SpaceToBatchNDOpModel m({1, 4, 4, 1}, {2, 2}, {0, 0}, {0, 0}); +TEST(SpaceToBatchNDOpTest, SimpleConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); @@ -68,17 +108,39 @@ TEST(SpaceToBatchNDOpTest, SimpleTest) { 13, 15, 6, 8, 14, 16})); } -TEST(SpaceToBatchNDOpTest, MultipleInputBatches) { - SpaceToBatchNDOpModel m({2, 2, 4, 1}, {2, 2}, {0, 0}, {0, 0}); +TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { + SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetPaddings({0, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { + SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { + SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetPaddings({0, 0, 0, 0}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16})); } -TEST(SpaceToBatchNDOpTest, SimplePadding) { - SpaceToBatchNDOpModel m({1, 5, 2, 1}, {3, 2}, {1, 2}, {0, 0}); +TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); @@ -88,9 +150,36 @@ TEST(SpaceToBatchNDOpTest, SimplePadding) { })); } -TEST(SpaceToBatchNDOpTest, ComplexPadding) { - SpaceToBatchNDOpModel m({1, 4, 2, 1}, {3, 2}, {1, 2}, {1, 4}); +TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 0, 2, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, + 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10, + })); +} + +TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, + 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0, + 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, + })); +} + +TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 1, 2, 4}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({ diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index 91ba4a9b7851c35a5138f4ccea307c810a4731a1..c4ffdf79d3aa7d47b9747bdf4208f8317d9fd22e 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -57,65 +57,6 @@ struct StridedSliceContext { int dims; }; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - - StridedSliceContext op_context(context, node); - - // Ensure validity of input tensor and its dimension - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - // Only INT32 begin/end/strides are supported - // TODO(soroosh) add support for INT64 - TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); - TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, - "StridedSlice op only supports 1D-4D input arrays."); - - // TODO(soroosh): add the following missing functionalities - TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, - "ellipsis_mask is not implemented yet."); - TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, - "new_axis_mask is not implemented yet."); - TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0, - "shrink_axis_mask is not implemented yet."); - - // TODO(soroosh): optimize for constant tensors to do allocation in Prepare - op_context.output->allocation_type = kTfLiteDynamic; - return kTfLiteOk; -} // namespace strided_slice - -// TODO(soroosh): consolidate with BytesRequired in interpreter.h -TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type, - const int* dims, int dims_size, size_t* bytes) { - // TODO(aselle): Check for overflow here using overflow.h in TensorFlow - // MultiplyWithoutOverflow. - TF_LITE_ENSURE(context, bytes != nullptr); - size_t count = 1; - for (int k = 0; k < dims_size; k++) count *= dims[k]; - switch (type) { - case kTfLiteFloat32: - *bytes = sizeof(float) * count; - break; - case kTfLiteInt32: - *bytes = sizeof(int32_t) * count; - break; - case kTfLiteUInt8: - *bytes = sizeof(uint8_t) * count; - break; - case kTfLiteInt64: - *bytes = sizeof(int64_t) * count; - break; - default: - return kTfLiteError; - } - return kTfLiteOk; -} - // Reverse order of bits in the mask to match the expected order in kernel inline int ReverseMaskBits(int mask, int num_dimensions) { int out = 0; @@ -146,40 +87,111 @@ inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { std::min(std::max(index, -dim), dim - 1), dim)); } +inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) { + const int dim = op_context->input->dims->data[idx]; + const bool pos_stride = GetTensorData(op_context->strides)[idx] > 0; + return op_context->params->begin_mask & (1 << idx) + ? pos_stride ? 0 : dim - 1 + : ClampedIndex(GetTensorData(op_context->begin)[idx], dim, + pos_stride); +} + +inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) { + const int dim = op_context->input->dims->data[idx]; + const bool pos_stride = GetTensorData(op_context->strides)[idx] > 0; + return op_context->params->end_mask & (1 << idx) + ? pos_stride ? dim : -1 + : ClampedIndex(GetTensorData(op_context->end)[idx], dim, + pos_stride); +} + +// Processes the indexing tensors (begin, end and strides) to resize the +// output tensor. This function is callable from both Prepare() and Eval() as +// long as the caller ensures the indexing tensors are present. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + StridedSliceContext* op_context) { + std::vector output_shape_vector; + + for (int idx = op_context->dims - 1; idx >= 0; --idx) { + int32_t stride = GetTensorData(op_context->strides)[idx]; + TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); + + int32_t begin = GetBeginValueAtIndex(op_context, idx); + int32_t end = GetEndValueAtIndex(op_context, idx); + + // This is valid for both positive and negative strides + int32_t dim_shape = ceil((end - begin) / static_cast(stride)); + dim_shape = dim_shape < 0 ? 0 : dim_shape; + if (!(op_context->params->shrink_axis_mask & (1 << idx))) { + output_shape_vector.push_back(dim_shape); + } + } + + TfLiteIntArray* output_shape = + TfLiteIntArrayCreate(output_shape_vector.size()); + + std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(), + output_shape->data); + + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, op_context->output, output_shape)); + + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + StridedSliceContext op_context(context, node); + + // Ensure validity of input tensor and its dimension + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + // Only INT32 begin/end/strides are supported + // TODO(soroosh) add support for INT64 + TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, + "StridedSlice op only supports 1D-4D input arrays."); + + // TODO(soroosh): add the following missing functionalities + TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, + "ellipsis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, + "new_axis_mask is not implemented yet."); + + // Postpone allocation of output if any of the indexing tensors is not + // constant + if (!(IsConstantTensor(op_context.begin) && + IsConstantTensor(op_context.end) && + IsConstantTensor(op_context.strides))) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { StridedSliceContext op_context(context, node); - std::vector starts; - std::vector stops; - std::vector strides; + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } - // Determine size of output tensor and map indices - TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims); - for (int idx = op_context.dims - 1; idx >= 0; --idx) { - int dim = op_context.input->dims->data[idx]; - int32_t stride = GetTensorData(op_context.strides)[idx]; - TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); - bool pos_stride = stride > 0; - - int32_t begin = - op_context.params->begin_mask & (1 << idx) - ? pos_stride ? 0 : dim - 1 - : ClampedIndex(GetTensorData(op_context.begin)[idx], dim, - pos_stride); - int32_t end = - op_context.params->end_mask & (1 << idx) - ? pos_stride ? dim : -1 - : ClampedIndex(GetTensorData(op_context.end)[idx], dim, - pos_stride); + std::vector starts; + std::vector stops; + std::vector strides; - // This is valid for both positive and negative strides - output_shape->data[idx] = ceil((end - begin) / static_cast(stride)); - output_shape->data[idx] = - output_shape->data[idx] < 0 ? 0 : output_shape->data[idx]; - starts.emplace_back(begin); - stops.emplace_back(end); - strides.emplace_back(stride); + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + starts.emplace_back(GetBeginValueAtIndex(&op_context, idx)); + stops.emplace_back(GetEndValueAtIndex(&op_context, idx)); + strides.emplace_back(GetTensorData(op_context.strides)[idx]); } for (int i = op_context.dims; i < kMaxDim; i++) { @@ -188,27 +200,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { strides.emplace_back(1); } - TF_LITE_ENSURE_STATUS( - context->ResizeTensor(context, op_context.output, output_shape)); - - size_t required_bytes; - TF_LITE_ENSURE_OK( - context, - BytesRequired(context, op_context.output->type, output_shape->data, - output_shape->size, &required_bytes)); - TfLiteTensorRealloc(required_bytes, op_context.output); - op_context.params->begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); op_context.params->end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); - -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice( \ - GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), op_context.params->begin_mask, \ - op_context.params->end_mask, starts, stops, strides, \ - GetTensorData(op_context.output), \ + op_context.params->shrink_axis_mask = + ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), op_context.params->begin_mask, \ + op_context.params->end_mask, op_context.params->shrink_axis_mask, \ + starts, stops, strides, GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index cd4a364682c0e66b2ceec92c0b34461945caf779..5cac04b38364958c5b0794c21742e8b592372ae9 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -21,6 +21,7 @@ limitations under the License. namespace tflite { namespace { +using ::int32; using ::testing::ElementsAreArray; class StridedSliceOpModel : public SingleOpModel { @@ -79,8 +80,6 @@ TEST(StridedSliceOpTest, UnssupportedArgs) { "ellipsis_mask is not implemented yet."); EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), "new_axis_mask is not implemented yet."); - EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1), - "shrink_axis_mask is not implemented yet."); } TEST(StridedSliceOpTest, In1D) { @@ -213,6 +212,7 @@ TEST(StridedSliceOpTest, In1D_EndMask) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); } + TEST(StridedSliceOpTest, In1D_NegStride) { StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); @@ -234,6 +234,7 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } + TEST(StridedSliceOpTest, In1D_OddLenStride2) { StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); @@ -255,6 +256,7 @@ TEST(StridedSliceOpTest, In2D_Identity) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } + TEST(StridedSliceOpTest, In2D) { StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -320,6 +322,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); } + TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -354,6 +357,7 @@ TEST(StridedSliceOpTest, In3D_NegStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})); } + TEST(StridedSliceOpTest, In3D_Strided2) { StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); @@ -365,6 +369,159 @@ TEST(StridedSliceOpTest, In3D_Strided2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5})); } +TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({2}); + m.SetEnd({1}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-2}); + m.SetEnd({-3}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 72f705fe4242b01c1516c99d3500484e8729fd9a..c69755447d5093e25d408eb6dea80750937465e7 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index 4de2ceaf053df31a4bc857fb250db416c071e80f..0f166dc69b95f3459388135b3a6c4d9b73a31cb4 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite SVDF op. -#include #include +#include #include #include diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 3a58e7ec321f649a6cae4cc0969807c2c74c6529..6f56aa6bf38781e860e33e8ac3b6a0bb8b50bb01 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -172,11 +172,14 @@ void SingleOpModel::BuildInterpreter( auto* model = GetModel(builder_.GetBufferPointer()); - ops::builtin::BuiltinOpResolver builtins; - for (const auto& reg : custom_registrations_) { - builtins.AddCustom(reg.first.data(), reg.second()); + if (!resolver_) { + auto resolver = new ops::builtin::BuiltinOpResolver(); + for (const auto& reg : custom_registrations_) { + resolver->AddCustom(reg.first.data(), reg.second()); + } + resolver_ = std::unique_ptr(resolver); } - InterpreterBuilder(model, builtins)(&interpreter_); + InterpreterBuilder(model, *resolver_)(&interpreter_); CHECK(interpreter_ != nullptr); diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index cc445299ff9f0b75610c7ff38f28facbbbe5587d..7d476ba1eaffbb24fb77390c0e71c32d60b6411e 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -85,6 +85,23 @@ struct TensorData { int32_t zero_point; }; +class SingleOpResolver : public OpResolver { + public: + SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration) + : op_(op), registration_(registration) {} + TfLiteRegistration* FindOp(BuiltinOperator op) const override { + if (op == op_) { + return registration_; + } + return nullptr; + } + TfLiteRegistration* FindOp(const char* op) const override { return nullptr; } + + private: + const BuiltinOperator op_; + TfLiteRegistration* registration_; +}; + class SingleOpModel { public: SingleOpModel() {} @@ -178,11 +195,16 @@ class SingleOpModel { return result; } + void SetResolver(std::unique_ptr resolver) { + resolver_ = std::move(resolver); + } + protected: int32_t GetTensorSize(int index) const; flatbuffers::FlatBufferBuilder builder_; std::unique_ptr interpreter_; + std::unique_ptr resolver_; private: int AddTensor(TensorData t, std::initializer_list data); @@ -197,6 +219,36 @@ class SingleOpModel { std::map> custom_registrations_; }; +// Base class for single op unit tests. +// The tests are parameterized to test multiple kernels for a single op. +// The parameters are strings like "optimized" and "reference" to have better +// readability in test reports. +// +// To use this class: +// * Define a constant map from strings to TfLiteRegistration. +// * Implement a test class that inherits SingleOpTest. +// * Instantiate the test cases with SingleOpTest::GetKernelTags helper +// function. +// * Call GetRegistration to get the TfLiteRegistration to be used before +// building the interpreter. +class SingleOpTest : public ::testing::TestWithParam { + public: + static std::vector GetKernelTags( + const std::map& kernel_map) { + std::vector tags; + for (auto it : kernel_map) { + tags.push_back(it.first); + } + return tags; + } + + protected: + virtual const std::map& GetKernelMap() = 0; + TfLiteRegistration* GetRegistration() { + return GetKernelMap().at(GetParam()); + } +}; + // Strings have a special implementation that is in test_util.cc template <> std::vector SingleOpModel::ExtractVector(int index); diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index f5f1ec2cf3f45ae730b849b18e2b85fac50159c7..ac00c37b67dcbe77023a2495a698967ca555b1d5 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -15,14 +15,15 @@ limitations under the License. #include #include #include -#include #include +#include #include #include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -82,48 +83,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size_array->data[0] = (time_major) ? max_time : batch_size; output_size_array->data[1] = (time_major) ? batch_size : max_time; output_size_array->data[2] = num_units; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, - output_size_array)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); return kTfLiteOk; } -namespace { -void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr, - const float* recurrent_weights_ptr, const float* bias_ptr, - int input_size, int num_units, int input_weights_stride, - int recurrent_weights_stride, TfLiteFusedActivation activation, - float* hidden_state_ptr_batch, float* output_ptr_batch) { - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } -} -} // namespace - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -147,30 +112,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { (time_major) ? input->dims->data[0] : input->dims->data[1]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[2]; - const int input_weights_stride = input_weights->dims->data[1]; - const int recurrent_weights_stride = recurrent_weights->dims->data[1]; // Initialize input_weights and recurrent_weights. const float* input_weights_ptr = input_weights->data.f; const float* recurrent_weights_ptr = recurrent_weights->data.f; if (time_major) { - // Unroll the sequence + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Unroll the sequence and use batch batch operations for efficiency. for (int s = 0; s < max_time; s++) { - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to hidden state. - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - // Initialize the pointer to input and output. - const float* input_ptr_batch = - input->data.f + s * input_size * batch_size + b * input_size; - float* output_ptr_batch = - output->data.f + s * num_units * batch_size + b * num_units; - - RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, - bias_ptr, input_size, num_units, input_weights_stride, - recurrent_weights_stride, params->activation, - hidden_state_ptr_batch, output_ptr_batch); - } + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + s * input_size * batch_size; + float* output_ptr_batch = output->data.f + s * num_units * batch_size; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } } else { // For each batch @@ -184,10 +144,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_ptr_batch = output->data.f + b * num_units * max_time + s * num_units; - RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, - bias_ptr, input_size, num_units, input_weights_stride, - recurrent_weights_stride, params->activation, - hidden_state_ptr_batch, output_ptr_batch); + kernel_utils::RnnBatchStep( + input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr, + input_size, num_units, /*batch_size=*/1, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } } } diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc index 82c680ec3d8656004d721c8498292677cb061b6b..7e32969763b59620dc3534708f965750680002d2 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite Sequential RNN op. -#include #include +#include #include #include @@ -120,8 +120,7 @@ static float rnn_golden_output[] = { 0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, - 0.628881, 3.58099, 1.49974, 0 -}; + 0.628881, 3.58099, 1.49974, 0}; class UnidirectionalRNNOpModel : public SingleOpModel { public: diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh new file mode 100755 index 0000000000000000000000000000000000000000..b58ae266017caf8781c28331f49a8f5bc1550767 --- /dev/null +++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh @@ -0,0 +1,81 @@ +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +echo "Starting" +TFLITE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/.." + +TMP_DIR=$(mktemp -d) +echo "Package dir: " $TMP_DIR +FW_DIR=$TMP_DIR/tensorflow_lite_ios_frameworks +FW_DIR_TFLITE=$FW_DIR/tensorflow_lite.framework +FW_DIR_TFLITE_HDRS=$FW_DIR_TFLITE/Headers + +echo "Creating target Headers directories" +mkdir -p $FW_DIR_TFLITE_HDRS + +echo "Headers, populating: TensorFlow Lite" +cd $TFLITE_DIR/../../.. + +find tensorflow/contrib/lite -name '*.h' \ + -not -path 'tensorflow/contrib/lite/downloads/*' \ + -not -path 'tensorflow/contrib/lite/examples/*' \ + -not -path 'tensorflow/contrib/lite/gen/*' \ + -not -path 'tensorflow/contrib/lite/toco/*' \ + -not -path 'tensorflow/contrib/lite/nnapi/*' \ + -not -path 'tensorflow/contrib/lite/java/*' \ + | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - +cd $FW_DIR_TFLITE_HDRS +tar xf tmp.tar +rm -f tmp.tar + +echo "Headers, populating: Flatbuffer" +cd $TFLITE_DIR/downloads/flatbuffers/include/ +find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - +cd $FW_DIR_TFLITE_HDRS +tar xf tmp.tar +rm -f tmp.tar + +cd $TFLITE_DIR/../../.. +echo "Generate master LICENSE file and copy to target" +bazel build //tensorflow/tools/lib_package:clicenses_generate +cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE \ + $FW_DIR_TFLITE + +echo "Copying static libraries" +cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \ + $FW_DIR_TFLITE/tensorflow_lite + +# This is required, otherwise they interfere with the documentation of the +# pod at cocoapods.org. +echo "Remove all README files" +cd $FW_DIR_TFLITE_HDRS +find . -type f -name README\* -exec rm -f {} \; +find . -type f -name readme\* -exec rm -f {} \; + +TARGET_GEN_LOCATION="$TFLITE_DIR/gen/ios_frameworks" +echo "Moving results to target: " $TARGET_GEN_LOCATION +cd $FW_DIR +zip -q -r tensorflow_lite.framework.zip tensorflow_lite.framework -x .DS_Store +rm -rf $TARGET_GEN_LOCATION +mkdir -p $TARGET_GEN_LOCATION +cp -r tensorflow_lite.framework.zip $TARGET_GEN_LOCATION + +echo "Cleaning up" +rm -rf $TMP_DIR + +echo "Finished" diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index ec4d6e3487a8207fcf4ed2d17aac104d83d6a782..14b6709964b54a6532273a69cca51c560b1cc103 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -469,6 +469,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ResizeBilinearOptions()) { + params->align_corners = schema_params->align_corners(); } builtin_data = reinterpret_cast(params); break; @@ -516,41 +517,9 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SPACE_TO_BATCH_ND: { - auto* params = MallocPOD(); - if (auto* schema_params = - op->builtin_options_as_SpaceToBatchNDOptions()) { - const auto& block_shape = schema_params->block_shape(); - FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, - params->block_shape, error_reporter); - const auto& before_paddings = schema_params->before_paddings(); - FlatBufferIntVectorToArray(sizeof(params->before_paddings), - before_paddings, params->before_paddings, - error_reporter); - const auto& after_paddings = schema_params->after_paddings(); - FlatBufferIntVectorToArray(sizeof(params->after_paddings), - after_paddings, params->after_paddings, - error_reporter); - params->num_spatial_dimensions = block_shape->Length(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_BATCH_TO_SPACE_ND: { - auto* params = MallocPOD(); - if (auto* schema_params = - op->builtin_options_as_BatchToSpaceNDOptions()) { - const auto& block_shape = schema_params->block_shape(); - FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, - params->block_shape, error_reporter); - const auto& before_crops = schema_params->before_crops(); - FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops, - params->before_crops, error_reporter); - const auto& after_crops = schema_params->after_crops(); - FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops, - params->after_crops, error_reporter); - params->num_spatial_dimensions = block_shape->Length(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_TRANSPOSE: { @@ -559,11 +528,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_MEAN: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_MeanOptions()) { - const auto& axis = schema_params->axis(); - FlatBufferIntVectorToArray(sizeof(params->axis), axis, params->axis, - error_reporter); params->keep_dims = schema_params->keep_dims(); - params->num_axis_dimensions = axis->Length(); } builtin_data = reinterpret_cast(params); break; diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 3d6a3ec0fd4c673f601254b19452bbf8b9454e27..2d8c49b7d7a5ae5c180f100a399a1870679c455f 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -13,6 +13,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":op_hint", "//tensorflow/contrib/lite/toco:model_flags_proto_py", "//tensorflow/contrib/lite/toco:toco_flags_proto_py", "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco", @@ -20,6 +21,17 @@ py_library( ], ) +py_library( + name = "op_hint", + srcs = ["op_hint.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:platform", + ], +) + py_test( name = "lite_test", srcs = ["lite_test.py"], @@ -27,6 +39,7 @@ py_test( tags = ["no_oss"], deps = [ ":lite", + ":op_hint", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 3c369774beda57cca3bc1ea0ab9a9ad619841e7e..5d2f21653762a405a57288a7ba38323e5e42b3e1 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -18,16 +18,21 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. @@toco_convert @@toco_convert_protos +@@OpHint +@@convert_op_hints_to_stubs """ from __future__ import absolute_import from __future__ import division from __future__ import print_function - import os import subprocess import tempfile +# pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs +from tensorflow.contrib.lite.python.op_hint import OpHint +# pylint: enable=unused-import from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 7d55f3fe6fe41a5d9e4e57c7a8e664bba6887fc7..b8b4510188bee867b32ffde714b27f41a1df778a 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -18,10 +18,14 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes +from tensorflow.python.framework.graph_util_impl import _extract_graph_summary from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -35,7 +39,8 @@ class LiteTest(test_util.TensorFlowTestCase): # Try running on valid graph result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor]) self.assertTrue(result) - # TODO(aselle): remove tests that fail. + # TODO(aselle): remove tests that fail (we must get TOCO to not fatal + # all the time). # Try running on identity graph (known fail) # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"): # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) @@ -51,5 +56,116 @@ class LiteTest(test_util.TensorFlowTestCase): quantized_input_stats=[(0., 1.)]) self.assertTrue(result) + +class LiteTestOpHint(test_util.TensorFlowTestCase): + """Test the hint to stub functionality.""" + + def _getGraphOpTypes(self, graphdef, output_nodes): + """Returns used op types in `graphdef` reachable from `output_nodes`. + + This is used to check that after the stub transformation the expected + nodes are there. Typically use this with self.assertCountEqual(...). + + NOTE: this is not a exact test that the graph is the correct output, but + it balances compact expressibility of test with sanity checking. + + Args: + graphdef: TensorFlow proto graphdef. + output_nodes: A list of output node names that we need to reach. + + Returns: + A set of node types reachable from `output_nodes`. + """ + name_to_input_name, name_to_node, _ = ( + _extract_graph_summary(graphdef)) + # Find all nodes that are needed by the outputs + used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name) + return set([name_to_node[node_name].op for node_name in used_node_names]) + + def _countIdentities(self, nodes): + """Count the number of "Identity" op types in the list of proto nodes. + + Args: + nodes: NodeDefs of the graph. + + Returns: + The number of nodes with op type "Identity" found. + """ + return len([x for x in nodes if x.op == "Identity"]) + + def testSwishLiteHint(self): + """Makes a custom op swish and makes sure it gets converted as a unit.""" + image = array_ops.constant([1., 2., 3., 4.]) + swish_scale = array_ops.constant(1.0) + + def _swish(input_tensor, scale): + custom = lite.OpHint("cool_activation") + input_tensor, scale = custom.add_inputs(input_tensor, scale) + output = math_ops.sigmoid(input_tensor) * input_tensor * scale + output, = custom.add_outputs(output) + return output + output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput") + + with self.test_session() as sess: + # check if identities have been put into the graph (2 input, 1 output, + # and 1 final output). + self.assertEqual(self._countIdentities(sess.graph_def.node), 4) + + stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + + self.assertCountEqual( + self._getGraphOpTypes( + stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + ["cool_activation", "Const", "Identity"]) + + def testScaleAndBiasAndIdentity(self): + """This tests a scaled add which has 3 inputs and 2 outputs.""" + a = array_ops.constant(1.) + x = array_ops.constant([2., 3.]) + b = array_ops.constant([4., 5.]) + + def _scaled_and_bias_and_identity(a, x, b): + custom = lite.OpHint("scale_and_bias_and_identity") + a, x, b = custom.add_inputs(a, x, b) + return custom.add_outputs(a * x + b, x) + output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b), + name="ModelOutput") + + with self.test_session() as sess: + # make sure one identity for each input (3) and output (2) => 3 + 2 = 5 + # +1 for the final output + self.assertEqual(self._countIdentities(sess.graph_def.node), 6) + + stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + + self.assertCountEqual( + self._getGraphOpTypes( + stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + ["scale_and_bias_and_identity", "Const", "Identity", "Pack"]) + + def testTwoFunctions(self): + """Tests if two functions are converted correctly.""" + a = array_ops.constant([1.]) + b = array_ops.constant([1.]) + def _double_values(x): + custom = lite.OpHint("add_test") + x = custom.add_inputs(x) + output = math_ops.multiply(x, x) + output, = custom.add_outputs(output) + return output + output = array_ops.identity( + math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput") + + with self.test_session() as sess: + # make sure one identity for each input (2) and output (2) => 2 + 2 + # +1 for the final output + self.assertEqual(self._countIdentities(sess.graph_def.node), 5) + stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + self.assertCountEqual( + self._getGraphOpTypes( + stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + ["add_test", "Const", "Identity", "Add"]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py new file mode 100644 index 0000000000000000000000000000000000000000..7c587e38b16dc3011fc7c8bef4eec4d0ea99ec21 --- /dev/null +++ b/tensorflow/contrib/lite/python/op_hint.py @@ -0,0 +1,291 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define tflite op hints (intrinsic operations). + +This essentially allows defining a TensorFlow API for tflite operations in +Python with hints on how they are represented in TensorFlow Lite. This basically +is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution +graph and is useful for LSTMs and other complicated TensorFlow constructions +that are difficult to pattern match in TOCO, but are represented by a single +accelerated tflite op. + +Example: + def tflite_cool_activation(input): + # A cool activation function. + custom = tf.contrib.lite.OpHint("cool_activation") + input = custom.add_inputs(input) + output = tf.sigmoid(input) * input + custom.add_outputs(output) + return output + + image = tf.placeholder(tf.float32, (1, 16, 16, 1)) + output = tf.identity(tflite_cool_activation(image)) + + session = tf.Session() + + graphdef_to_convert = tf.contrib.lite.convert_op_hints_to_stubs(session) + tflite_graph = tf.contrib.lite.toco_convert(graphdef_to_convert, + [image], [output]) + [image], [output]) + with open("/tmp/graph.fb", "wb") as fp: + fp.write(tflite_graph) + +How does it work?: + +OpHint is a helper that you use when defining a vanilla python function. +It allows you to wrap arguments with tf.identities with some custom attributes. +These attributes allow you to find the original block of ops that was created. +For example, if you use cool_activation above you essentially get: + +a_input = tf.identity() +result = tf.multiply(tf.sigmoid(a_input), a_input) +output = tf.identity() + +a_input, output are identities that have parameters representing +what argument they are, what the name of the function they should turn into +in tf lite as well as a guid that uniquely identifies a particular invocation. + +Once you have built your whole tensorflow graph, you can run it and train it +as usual, but after you have done that, you need to convert the graph into +a form that replaces these subgraphs wrapped in identities to stub ops. These +ops don't actually exist in the normal TensorFlow runtime, but will be +understood by toco later. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections as _collections +import itertools as _itertools +import uuid as _uuid + +from tensorflow.contrib import framework as _framework +from tensorflow.python.framework import ops as _ops +from tensorflow.python.ops import array_ops as _array_ops +from tensorflow.python.util.all_util import remove_undocumented + + +class OpHint(object): + """A class that helps build tflite function invocations. + + It allows you to take a bunch of TensorFlow ops and annotate the construction + such that toco knows how to convert it to tflite. This embeds a pseudo + function in a TensorFlow graph. This allows embedding high-level API usage + information in a lower level TensorFlow implementation so that an alternative + implementation can be substituted later. + + Essentially, any "input" into this pseudo op is fed into an identity, and + attributes are added to that input before being used by the constituent ops + that make up the pseudo op. A similar process is done to any output that + is to be exported from the current op. + + TODO(aselle): When TensorFlow functions functionality works for arbitrary + constructs, this mechanism can be retired and changed to use python defun's. + """ + + # Attr constants that are used for representation in the GraphDef + FUNCTION_NAME_ATTR = "_tflite_function_name" + FUNCTION_UUID_ATTR = "_tflite_function_uuid" + FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index" + FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index" + + def __init__(self, function_name, **kwargs): + """Create a OpHint. + + Args: + function_name: Name of the function (the custom op name in tflite) + **kwargs: Keyword arguments of any constant attributes for the function. + """ + self._function_name = function_name + self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough? + self._curr_input_index = 0 + self._curr_output_index = 0 + self._attrs_to_store_later = kwargs + self._stored_attrs = False + + def _setattr(self, dest_op, name, value): + tensor_value = _ops.convert_to_tensor(value) + dest_op.op.node_def.attr[name].tensor.CopyFrom( + tensor_value.op.node_def.attr["value"].tensor) + + def add_inputs(self, *args): + """Add a sequence of inputs to the function invocation. + + Args: + *args: List of inputs to be converted (should be Tf.Tensor). + Returns: + Wrapped inputs (identity standins that have additional metadata). These + are also are also tf.Tensor's. + """ + + def augmented_identity(arg): + identity_op = _array_ops.identity(arg) + attr = identity_op.op.node_def.attr + attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name + attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id + attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i = self._curr_input_index + self._curr_input_index += 1 + return identity_op + + return [augmented_identity(arg) for arg in args] + + def add_outputs(self, *args): + """Add a sequence of outputs to the function invocation. + + Args: + *args: List of outputs to be converted (should be tf.Tensor). + Returns: + Wrapped outputs (identity standins that have additional metadata). These + are also tf.Tensor's. + """ + + def augmented_identity(arg): + identity_op = _array_ops.identity(arg) + attr = identity_op.op.node_def.attr + attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name + attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id + attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i = self._curr_output_index + self._curr_output_index += 1 + return identity_op + + wrapped_outputs = [augmented_identity(arg) for arg in args] + + if not self._stored_attrs: + for key, value in self._attrs_to_store_later.iteritems(): + self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value) + self._stored_attrs = True + + return wrapped_outputs + + +class _LiteFuncCall(object): + """Represent a TensorFlow Lite custom function. + + This is uses to accumulate found hints in the graphdef into a single + conceptual unit. + + Properties: + self.inputs: inputs to the op (hash from index # to argument) + self.outputs: outputs to the op (hash from index # to argument) + self.function_name: the tflite custom op name to use + self.uuid: a unique call id for this particular call (i.e. + multiple function calls would have the same function_name but different + uuids. + self.params: A param name to key value for op constant data. I.e. for + axis on a reduction, strides on a convolution, etc. + """ + + def __init__(self): + self.inputs = {} + self.outputs = {} + self.function_name = None + self.uuid = None + self.params = {} + + def __str__(self): + return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % ( + self.function_name, self.uuid, self.inputs, self.outputs) + + +def _find_all_hints_in_graph_def(session): + """Look at the current default graph and return a list of LiteFuncCall objs. + + Args: + session: A TensorFlow session that contains the graph to convert. + Returns: + a list of `LifeFuncCall` objects in the form + + """ + func_calls = _collections.defaultdict(_LiteFuncCall) + seen_ops = set() + + for op in session.graph.get_operations(): + for operand in _itertools.chain(op.inputs, op.outputs): + if operand in seen_ops: + continue + seen_ops.add(operand) + attr = operand.op.node_def.attr + uuid = attr[OpHint.FUNCTION_UUID_ATTR].s + if OpHint.FUNCTION_UUID_ATTR not in attr: + continue + call_def = func_calls[uuid] + call_def.uuid = uuid + if OpHint.FUNCTION_UUID_ATTR in attr: + call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s + if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: + call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand + if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr: + call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand + + for a in attr: + if a.startswith("_tflite_attr_"): + # TODO(aselle): Remember the attribute tensors so we can put them + # in collapse. + call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor + + return func_calls + + +def _tensor_name_base(full_tensor_name): + """Removes the device assignment code from a tensor. + + e.g. _tensor_name_base("foo:3") => "foo" + + Args: + full_tensor_name: A tensor name that is annotated with a device placement + (this is what tensor flow introspection gives). + Returns: + A name without any device assignment. + """ + return full_tensor_name.name.split(":")[0] + + +def convert_op_hints_to_stubs(session): + """Converts a graphdef with LiteOp hints into stub operations. + + This is used to prepare for toco conversion of complex intrinsic usages. + + Args: + session: A TensorFlow session that contains the graph to convert. + Returns: + A new graphdef with all ops contained in OpHints being replaced by + a single op call with the right parameters. + """ + hints = _find_all_hints_in_graph_def(session) + current_graph_def = session.graph_def + for call in hints.values(): + input_names = [None] * len(call.inputs) + output_names = [None] * len(call.outputs) + output_dtypes = [None] * len(call.outputs) + output_quantized = False + for input_index, tensor in call.inputs.items(): + input_names[input_index] = _tensor_name_base(tensor) + for output_index, tensor in call.outputs.items(): + output_names[output_index] = _tensor_name_base(tensor) + output_dtypes[output_index] = tensor.dtype.as_datatype_enum + # TODO(aselle): Support quantized flag properly + current_graph_def = _framework.fuse_op( + current_graph_def, input_names, output_names, output_dtypes, + output_quantized, call.uuid, call.function_name) + for node in current_graph_def.node: + if node.name == call.uuid: + for param, tensor in call.params.items(): + node.attr[param].tensor.CopyFrom(tensor) + return current_graph_def + + +_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 50709344eac1af25bd34a218b4aa35eeeeeb85b7..36cc2724eb1d927d39cff25a46a57aca4f572547 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -273,6 +273,9 @@ table LSTMOptions { } table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; } // A call operation options @@ -289,15 +292,9 @@ table ReshapeOptions { } table SpaceToBatchNDOptions { - block_shape:[int]; - before_paddings:[int]; - after_paddings:[int]; } table BatchToSpaceNDOptions { - block_shape:[int]; - before_crops:[int]; - after_crops:[int]; } table SkipGramOptions { @@ -336,7 +333,6 @@ table TransposeOptions { } table MeanOptions { - axis:[int]; keep_dims: bool; } diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index f1ee925df23894470dd951bf5ad00655a7614da0..e2ac0b9d1e05cdc7e89da32107044320d6e4ea5a 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -2626,28 +2626,36 @@ flatbuffers::Offset CreateLSTMOptions( struct ResizeBilinearOptionsT : public flatbuffers::NativeTable { typedef ResizeBilinearOptions TableType; - ResizeBilinearOptionsT() {} + bool align_corners; + ResizeBilinearOptionsT() + : align_corners(false) { + } }; -struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ResizeBilinearOptionsT NativeTableType; + enum { + VT_ALIGN_CORNERS = 8 + }; + bool align_corners() const { + return GetField(VT_ALIGN_CORNERS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && verifier.EndTable(); + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ALIGN_CORNERS) && + verifier.EndTable(); } - ResizeBilinearOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ResizeBilinearOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ResizeBilinearOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; + void add_align_corners(bool align_corners) { + fbb_.AddElement(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast(align_corners), 0); + } explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2661,14 +2669,14 @@ struct ResizeBilinearOptionsBuilder { }; inline flatbuffers::Offset CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb) { + flatbuffers::FlatBufferBuilder &_fbb, + bool align_corners = false) { ResizeBilinearOptionsBuilder builder_(_fbb); + builder_.add_align_corners(align_corners); return builder_.Finish(); } -flatbuffers::Offset CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct CallOptionsT : public flatbuffers::NativeTable { typedef CallOptions TableType; @@ -2834,33 +2842,14 @@ flatbuffers::Offset CreateReshapeOptions( struct SpaceToBatchNDOptionsT : public flatbuffers::NativeTable { typedef SpaceToBatchNDOptions TableType; - std::vector block_shape; - std::vector before_paddings; - std::vector after_paddings; SpaceToBatchNDOptionsT() {} }; struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SpaceToBatchNDOptionsT NativeTableType; - enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_PADDINGS = 6, VT_AFTER_PADDINGS = 8 }; - const flatbuffers::Vector *block_shape() const { - return GetPointer *>(VT_BLOCK_SHAPE); - } - const flatbuffers::Vector *before_paddings() const { - return GetPointer *>(VT_BEFORE_PADDINGS); - } - const flatbuffers::Vector *after_paddings() const { - return GetPointer *>(VT_AFTER_PADDINGS); - } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BLOCK_SHAPE) && - verifier.Verify(block_shape()) && - VerifyOffset(verifier, VT_BEFORE_PADDINGS) && - verifier.Verify(before_paddings()) && - VerifyOffset(verifier, VT_AFTER_PADDINGS) && - verifier.Verify(after_paddings()) && verifier.EndTable(); + return VerifyTableStart(verifier) && verifier.EndTable(); } SpaceToBatchNDOptionsT *UnPack( const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2875,18 +2864,6 @@ struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS struct SpaceToBatchNDOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_block_shape( - flatbuffers::Offset> block_shape) { - fbb_.AddOffset(SpaceToBatchNDOptions::VT_BLOCK_SHAPE, block_shape); - } - void add_before_paddings( - flatbuffers::Offset> before_paddings) { - fbb_.AddOffset(SpaceToBatchNDOptions::VT_BEFORE_PADDINGS, before_paddings); - } - void add_after_paddings( - flatbuffers::Offset> after_paddings) { - fbb_.AddOffset(SpaceToBatchNDOptions::VT_AFTER_PADDINGS, after_paddings); - } explicit SpaceToBatchNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2900,62 +2877,25 @@ struct SpaceToBatchNDOptionsBuilder { }; inline flatbuffers::Offset CreateSpaceToBatchNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> block_shape = 0, - flatbuffers::Offset> before_paddings = 0, - flatbuffers::Offset> after_paddings = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { SpaceToBatchNDOptionsBuilder builder_(_fbb); - builder_.add_after_paddings(after_paddings); - builder_.add_before_paddings(before_paddings); - builder_.add_block_shape(block_shape); return builder_.Finish(); } -inline flatbuffers::Offset -CreateSpaceToBatchNDOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *block_shape = nullptr, - const std::vector *before_paddings = nullptr, - const std::vector *after_paddings = nullptr) { - return tflite::CreateSpaceToBatchNDOptions( - _fbb, block_shape ? _fbb.CreateVector(*block_shape) : 0, - before_paddings ? _fbb.CreateVector(*before_paddings) : 0, - after_paddings ? _fbb.CreateVector(*after_paddings) : 0); -} - flatbuffers::Offset CreateSpaceToBatchNDOptions( flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct BatchToSpaceNDOptionsT : public flatbuffers::NativeTable { typedef BatchToSpaceNDOptions TableType; - std::vector block_shape; - std::vector before_crops; - std::vector after_crops; BatchToSpaceNDOptionsT() {} }; struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BatchToSpaceNDOptionsT NativeTableType; - enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_CROPS = 6, VT_AFTER_CROPS = 8 }; - const flatbuffers::Vector *block_shape() const { - return GetPointer *>(VT_BLOCK_SHAPE); - } - const flatbuffers::Vector *before_crops() const { - return GetPointer *>(VT_BEFORE_CROPS); - } - const flatbuffers::Vector *after_crops() const { - return GetPointer *>(VT_AFTER_CROPS); - } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BLOCK_SHAPE) && - verifier.Verify(block_shape()) && - VerifyOffset(verifier, VT_BEFORE_CROPS) && - verifier.Verify(before_crops()) && - VerifyOffset(verifier, VT_AFTER_CROPS) && - verifier.Verify(after_crops()) && verifier.EndTable(); + return VerifyTableStart(verifier) && verifier.EndTable(); } BatchToSpaceNDOptionsT *UnPack( const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2970,18 +2910,6 @@ struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS struct BatchToSpaceNDOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_block_shape( - flatbuffers::Offset> block_shape) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_BLOCK_SHAPE, block_shape); - } - void add_before_crops( - flatbuffers::Offset> before_crops) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_BEFORE_CROPS, before_crops); - } - void add_after_crops( - flatbuffers::Offset> after_crops) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_AFTER_CROPS, after_crops); - } explicit BatchToSpaceNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2995,29 +2923,11 @@ struct BatchToSpaceNDOptionsBuilder { }; inline flatbuffers::Offset CreateBatchToSpaceNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> block_shape = 0, - flatbuffers::Offset> before_crops = 0, - flatbuffers::Offset> after_crops = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { BatchToSpaceNDOptionsBuilder builder_(_fbb); - builder_.add_after_crops(after_crops); - builder_.add_before_crops(before_crops); - builder_.add_block_shape(block_shape); return builder_.Finish(); } -inline flatbuffers::Offset -CreateBatchToSpaceNDOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *block_shape = nullptr, - const std::vector *before_crops = nullptr, - const std::vector *after_crops = nullptr) { - return tflite::CreateBatchToSpaceNDOptions( - _fbb, block_shape ? _fbb.CreateVector(*block_shape) : 0, - before_crops ? _fbb.CreateVector(*before_crops) : 0, - after_crops ? _fbb.CreateVector(*after_crops) : 0); -} - flatbuffers::Offset CreateBatchToSpaceNDOptions( flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -3437,21 +3347,16 @@ flatbuffers::Offset CreateTransposeOptions( struct MeanOptionsT : public flatbuffers::NativeTable { typedef MeanOptions TableType; - std::vector axis; bool keep_dims; MeanOptionsT() : keep_dims(false) {} }; struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef MeanOptionsT NativeTableType; - enum { VT_AXIS = 4, VT_KEEP_DIMS = 6 }; - const flatbuffers::Vector *axis() const { - return GetPointer *>(VT_AXIS); - } + enum { VT_KEEP_DIMS = 4 }; bool keep_dims() const { return GetField(VT_KEEP_DIMS, 0) != 0; } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_AXIS) && - verifier.Verify(axis()) && + return VerifyTableStart(verifier) && VerifyField(verifier, VT_KEEP_DIMS) && verifier.EndTable(); } MeanOptionsT *UnPack( @@ -3467,9 +3372,6 @@ struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct MeanOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_axis(flatbuffers::Offset> axis) { - fbb_.AddOffset(MeanOptions::VT_AXIS, axis); - } void add_keep_dims(bool keep_dims) { fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); @@ -3487,22 +3389,12 @@ struct MeanOptionsBuilder { }; inline flatbuffers::Offset CreateMeanOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> axis = 0, - bool keep_dims = false) { + flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) { MeanOptionsBuilder builder_(_fbb); - builder_.add_axis(axis); builder_.add_keep_dims(keep_dims); return builder_.Finish(); } -inline flatbuffers::Offset CreateMeanOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *axis = nullptr, bool keep_dims = false) { - return tflite::CreateMeanOptions( - _fbb, axis ? _fbb.CreateVector(*axis) : 0, keep_dims); -} - flatbuffers::Offset CreateMeanOptions( flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5706,33 +5598,6 @@ inline void SpaceToBatchNDOptions::UnPackTo( const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = block_shape(); - if (_e) { - _o->block_shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->block_shape[_i] = _e->Get(_i); - } - } - }; - { - auto _e = before_paddings(); - if (_e) { - _o->before_paddings.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->before_paddings[_i] = _e->Get(_i); - } - } - }; - { - auto _e = after_paddings(); - if (_e) { - _o->after_paddings.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->after_paddings[_i] = _e->Get(_i); - } - } - }; } inline flatbuffers::Offset SpaceToBatchNDOptions::Pack( @@ -5752,14 +5617,7 @@ inline flatbuffers::Offset CreateSpaceToBatchNDOptions( const flatbuffers::rehasher_function_t *__rehasher; } _va = {&_fbb, _o, _rehasher}; (void)_va; - auto _block_shape = - _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0; - auto _before_paddings = - _o->before_paddings.size() ? _fbb.CreateVector(_o->before_paddings) : 0; - auto _after_paddings = - _o->after_paddings.size() ? _fbb.CreateVector(_o->after_paddings) : 0; - return tflite::CreateSpaceToBatchNDOptions(_fbb, _block_shape, - _before_paddings, _after_paddings); + return tflite::CreateSpaceToBatchNDOptions(_fbb); } inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack( @@ -5774,33 +5632,6 @@ inline void BatchToSpaceNDOptions::UnPackTo( const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = block_shape(); - if (_e) { - _o->block_shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->block_shape[_i] = _e->Get(_i); - } - } - }; - { - auto _e = before_crops(); - if (_e) { - _o->before_crops.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->before_crops[_i] = _e->Get(_i); - } - } - }; - { - auto _e = after_crops(); - if (_e) { - _o->after_crops.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->after_crops[_i] = _e->Get(_i); - } - } - }; } inline flatbuffers::Offset BatchToSpaceNDOptions::Pack( @@ -5820,14 +5651,7 @@ inline flatbuffers::Offset CreateBatchToSpaceNDOptions( const flatbuffers::rehasher_function_t *__rehasher; } _va = {&_fbb, _o, _rehasher}; (void)_va; - auto _block_shape = - _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0; - auto _before_crops = - _o->before_crops.size() ? _fbb.CreateVector(_o->before_crops) : 0; - auto _after_crops = - _o->after_crops.size() ? _fbb.CreateVector(_o->after_crops) : 0; - return tflite::CreateBatchToSpaceNDOptions(_fbb, _block_shape, _before_crops, - _after_crops); + return tflite::CreateBatchToSpaceNDOptions(_fbb); } inline SkipGramOptionsT *SkipGramOptions::UnPack( @@ -6122,15 +5946,6 @@ inline void MeanOptions::UnPackTo( MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = axis(); - if (_e) { - _o->axis.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->axis[_i] = _e->Get(_i); - } - } - }; { auto _e = keep_dims(); _o->keep_dims = _e; @@ -6154,9 +5969,8 @@ inline flatbuffers::Offset CreateMeanOptions( const flatbuffers::rehasher_function_t *__rehasher; } _va = {&_fbb, _o, _rehasher}; (void)_va; - auto _axis = _o->axis.size() ? _fbb.CreateVector(_o->axis) : 0; auto _keep_dims = _o->keep_dims; - return tflite::CreateMeanOptions(_fbb, _axis, _keep_dims); + return tflite::CreateMeanOptions(_fbb, _keep_dims); } inline SqueezeOptionsT *SqueezeOptions::UnPack( diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 50e8ca75f8efd600d4773b83cd2c8de11c9d13ca..b949045128fc15b6abe8f6c59d63dfd2b47c3c30 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -197,7 +197,7 @@ cc_binary( tf_cc_test( name = "generated_examples_zip_test", - size = "medium", + size = "large", srcs = ["generated_examples_zip_test.cc"], args = [ "--zip_files_dir=tensorflow/contrib/lite/testing/optest", @@ -206,7 +206,7 @@ tf_cc_test( "--unzip_binary_path=/usr/bin/unzip", ], data = [":optest"], - shard_count = 10, + shard_count = 20, tags = ["no_oss"], deps = [ ":parse_testdata_lib", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index bc9b23aeb42dd454af95756f640e6e9ddad3cdde..147400ec37c606308244cc862bbee5b88ba553ec 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -94,7 +94,8 @@ KNOWN_BUGS = { r"softmax.*input_shape=\[1,3,4,3\]": "67749831", # SpaceToDepth only supports float32. r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", - # BatchToSpaceND doesn't support cropping. + # BatchToSpaceND doesn't support cropping. This catches test cases with + # const tensors as crops. r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", @@ -618,7 +619,7 @@ def make_constant_tests(zip_path): def build_graph(parameters): # Since Toco & Tflite can't have a single constant op in the entire graph, - # this test adds a zero tesnor with a constant op tensor. + # this test adds a zero tensor with a constant op tensor. input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", shape=parameters["input_shape"]) out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 @@ -694,6 +695,7 @@ def make_mean_tests(zip_path): [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] ], + "const_axis": [True, False], "keep_dims": [True, False], }, { "input_dtype": [tf.float32, tf.int32, tf.int64], @@ -704,6 +706,7 @@ def make_mean_tests(zip_path): -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], [2, 2, 3], [-3, -3, -4], [-3, 2, 1] ], + "const_axis": [True, False], "keep_dims": [True, False], }] @@ -713,17 +716,31 @@ def make_mean_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) + + # Get axis as either a placeholder or constants. + if parameters["const_axis"]: + axis = parameters["axis"] + input_tensors = [input_tensor] + else: + if isinstance(parameters["axis"], list): + shape = [len(parameters["axis"])] + else: + shape = [0] # shape for None or integers. + axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) + input_tensors = [input_tensor, axis] + out = tf.reduce_mean( - input_tensor, - axis=parameters["axis"], - keep_dims=parameters["keep_dims"]) - return [input_tensor], [out] + input_tensor, axis=axis, keep_dims=parameters["keep_dims"]) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["input_dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) + ] + if not parameters["const_axis"]: + if parameters["axis"]: + values.append(np.array(parameters["axis"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -984,13 +1001,15 @@ def make_concatenation_tests(zip_path): test_parameters = [{ "base_shape": [[1, 3, 4, 3], [3, 4]], "num_tensors": [1, 2, 3, 4, 5, 6], - "axis": [0, 1, 2, 3], + "axis": [0, 1, 2, 3, -3, -2, -1], }] def get_shape(parameters, delta): """Return a tweaked version of 'base_shape'.""" axis = parameters["axis"] shape = parameters["base_shape"][:] + if axis < 0: + axis += len(shape) if axis < len(shape): shape[axis] += delta return shape @@ -1318,12 +1337,16 @@ def make_space_to_batch_nd_tests(zip_path): "input_shape": [[1, 2, 2, 3], [2, 2, 4, 1]], "block_shape": [[1, 3], [2, 2]], "paddings": [[[0, 0], [0, 0]], [[0, 0], [2, 0]], [[1, 1], [1, 1]]], + "constant_block_shape": [True, False], + "constant_paddings": [True, False], }, { "dtype": [tf.float32], "input_shape": [[2, 3, 7, 3]], "block_shape": [[1, 3], [2, 2]], "paddings": [[[0, 0], [2, 0]], [[1, 0], [1, 0]]], + "constant_block_shape": [True, False], + "constant_paddings": [True, False], }, # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others. { @@ -1331,23 +1354,47 @@ def make_space_to_batch_nd_tests(zip_path): "input_shape": [[1, 4, 4, 4, 1, 1]], "block_shape": [[2, 2, 2]], "paddings": [[[0, 0], [0, 0], [0, 0]]], + "constant_block_shape": [True, False], + "constant_paddings": [True, False], }, ] def build_graph(parameters): + """Build a space_to_batch graph given `parameters`.""" input_tensor = tf.placeholder( dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.space_to_batch_nd(input_tensor, parameters["block_shape"], - parameters["paddings"]) - return [input_tensor], [out] + input_tensors = [input_tensor] + + # Get block_shape either as a const or as a placeholder (tensor). + if parameters["constant_block_shape"]: + block_shape = parameters["block_shape"] + else: + shape = [len(parameters["block_shape"])] + block_shape = tf.placeholder(dtype=tf.int32, name="shape", shape=shape) + input_tensors.append(block_shape) + + # Get paddings either as a const or as a placeholder (tensor). + if parameters["constant_paddings"]: + paddings = parameters["paddings"] + else: + shape = [len(parameters["paddings"]), 2] + paddings = tf.placeholder(dtype=tf.int32, name="paddings", shape=shape) + input_tensors.append(paddings) + + out = tf.space_to_batch_nd(input_tensor, block_shape, paddings) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_block_shape"]: + values.append(np.array(parameters["block_shape"])) + if not parameters["constant_paddings"]: + values.append(np.array(parameters["paddings"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1361,6 +1408,8 @@ def make_batch_to_space_nd_tests(zip_path): "input_shape": [[12, 2, 2, 1]], "block_shape": [[1, 4], [2, 2], [3, 4]], "crops": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]], + "constant_block_shape": [True, False], + "constant_crops": [True, False], }, # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others. { @@ -1368,23 +1417,47 @@ def make_batch_to_space_nd_tests(zip_path): "input_shape": [[8, 2, 2, 2, 1, 1]], "block_shape": [[2, 2, 2]], "crops": [[[0, 0], [0, 0], [0, 0]]], + "constant_block_shape": [True, False], + "constant_crops": [True, False], }, ] def build_graph(parameters): + """Build a batch_to_space graph given `parameters`.""" input_tensor = tf.placeholder( dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.batch_to_space_nd(input_tensor, parameters["block_shape"], - parameters["crops"]) - return [input_tensor], [out] + input_tensors = [input_tensor] + + # Get block_shape either as a const or as a placeholder (tensor). + if parameters["constant_block_shape"]: + block_shape = parameters["block_shape"] + else: + shape = [len(parameters["block_shape"])] + block_shape = tf.placeholder(dtype=tf.int32, name="shape", shape=shape) + input_tensors.append(block_shape) + + # Get crops either as a const or as a placeholder (tensor). + if parameters["constant_crops"]: + crops = parameters["crops"] + else: + shape = [len(parameters["crops"]), 2] + crops = tf.placeholder(dtype=tf.int32, name="crops", shape=shape) + input_tensors.append(crops) + + out = tf.batch_to_space_nd(input_tensor, block_shape, crops) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_block_shape"]: + values.append(np.array(parameters["block_shape"])) + if not parameters["constant_crops"]: + values.append(np.array(parameters["crops"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1489,9 +1562,11 @@ def make_strided_slice_tests(zip_path): "input_shape": [[12, 2, 2, 5]], "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], "end": [[8, 2, 2, 3], [12, 2, 2, 5]], - "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]], - "begin_mask": [None, 0, 1, 2, 8], - "end_mask": [None, 0, 1, 2, 8], + "strides": [None, [2, 1, 3, 1]], + "begin_mask": [None, 1, 8], + "end_mask": [None, 1, 8], + "shrink_axis_mask": [None, 1, 8, 11, 15, -1], + "constant_indices": [False, True], }, # 2-D { @@ -1500,9 +1575,11 @@ def make_strided_slice_tests(zip_path): "input_shape": [[2, 3]], "begin": [[0, 0], [1, 0]], "end": [[2, 3], [2, 2]], - "strides": [None, [1, 1], [2, 2]], - "begin_mask": [None, 0, 1, 2], - "end_mask": [None, 0, 1, 2], + "strides": [None, [2, 2]], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False, True], }, # Negative strides { @@ -1512,8 +1589,10 @@ def make_strided_slice_tests(zip_path): "begin": [[0, -1]], "end": [[2, -3]], "strides": [[1, -1]], - "begin_mask": [None, 0, 1, 2], - "end_mask": [None, 0, 1, 2], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False], }, ] @@ -1523,23 +1602,29 @@ def make_strided_slice_tests(zip_path): dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - begin = tf.placeholder( - dtype=parameters["index_type"], - name="begin", - shape=[len(parameters["input_shape"])]) - end = tf.placeholder( - dtype=parameters["index_type"], - name="end", - shape=[len(parameters["input_shape"])]) - strides = ( - tf.placeholder( - dtype=parameters["index_type"], - name="strides", - shape=[len(parameters["input_shape"])]) - if parameters["strides"] is not None else None) - tensors = [input_tensor, begin, end] - if strides is not None: - tensors.append(strides) + if parameters["constant_indices"]: + begin = parameters["begin"] + end = parameters["end"] + strides = parameters["strides"] + tensors = [input_tensor] + else: + begin = tf.placeholder( + dtype=parameters["index_type"], + name="begin", + shape=[len(parameters["input_shape"])]) + end = tf.placeholder( + dtype=parameters["index_type"], + name="end", + shape=[len(parameters["input_shape"])]) + strides = ( + tf.placeholder( + dtype=parameters["index_type"], + name="strides", + shape=[len(parameters["input_shape"])]) + if parameters["strides"] is not None else None) + tensors = [input_tensor, begin, end] + if strides is not None: + tensors.append(strides) out = tf.strided_slice( input_tensor, begin, @@ -1554,14 +1639,17 @@ def make_strided_slice_tests(zip_path): input_values = create_tensor_data(parameters["dtype"], parameters["input_shape"]) index_type = _TF_TYPE_INFO[parameters["index_type"]][0] - begin_values = np.array(parameters["begin"]).astype(index_type) - end_values = np.array(parameters["end"]).astype(index_type) - stride_values = ( - np.array(parameters["strides"]).astype(index_type) - if parameters["strides"] is not None else None) - values = [input_values, begin_values, end_values] - if stride_values is not None: - values.append(stride_values) + values = [input_values] + if not parameters["constant_indices"]: + begin_values = np.array(parameters["begin"]).astype(index_type) + end_values = np.array(parameters["end"]).astype(index_type) + stride_values = ( + np.array(parameters["strides"]).astype(index_type) + if parameters["strides"] is not None else None) + values.append(begin_values) + values.append(end_values) + if stride_values is not None: + values.append(stride_values) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 2bbfe77a12370b9e3246cd060b858922fde12e1e..5ea3e21f6a1636d1e7029bed8e75b2f68f656103 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -47,9 +47,7 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - // Add doesn't support broadcasting. - {R"(^\/adda.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, - {R"(^\/mula.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + // Sub and Div don't support broadcasting. {R"(^\/diva.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, {R"(^\/suba.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, @@ -67,7 +65,11 @@ std::map kBrokenTests = { // L2Norm only supports tensors with 4D or fewer. {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, - // SpaceToBatch only supports 4D tensors. + // BatchToSpaceND doesn't support cropping. This catches test cases with + // non-const tensors as crops. + {R"(^\/batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\])", "70594634"}, + + // SpaceToBatchND only supports 4D tensors. {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. @@ -87,9 +89,6 @@ std::map kBrokenTests = { // ResizeBilinear looks completely incompatible with Tensorflow {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"}, // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, @@ -239,8 +238,7 @@ INSTANTIATE_TESTS(avg_pool) INSTANTIATE_TESTS(space_to_batch_nd) INSTANTIATE_TESTS(batch_to_space_nd) INSTANTIATE_TESTS(concat) -// TODO(b/71642435) re-enable this test -// INSTANTIATE_TESTS(constant) +INSTANTIATE_TESTS(constant) INSTANTIATE_TESTS(control_dep) INSTANTIATE_TESTS(conv) INSTANTIATE_TESTS(depthwiseconv) diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 6fc7e5e3fdd4da8f8b224b8c10a6be8154204c94..20c156a93262568cf0c6c349b44fbf3d3afa5bc4 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -205,6 +205,7 @@ cc_library( "graph_transformations/remove_trivial_quantized_activation_func.cc", "graph_transformations/remove_trivial_reshape.cc", "graph_transformations/remove_unused_op.cc", + "graph_transformations/reorder_activation_functions.cc", "graph_transformations/resolve_batch_normalization.cc", "graph_transformations/resolve_batch_to_space_nd_attributes.cc", "graph_transformations/resolve_constant_binary.cc", diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 529df3cd2e56f1888f3d431ddcd7dc7051a98355..be6d506bf3d17ce42bf97da222670fb06680fa13 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -621,7 +621,8 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && + providing_op->type == OperatorType::kTensorFlowReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -991,6 +992,7 @@ void ConvertResizeBilinearOperator(const Model& model, *resize_op->add_input() = src_op.inputs[0]; *resize_op->add_input() = src_op.inputs[1]; (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners); } namespace { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc index 88e59664ec427841df6f20686238feacef6a47e9..ab943f72d1dd87ae9ff4bd53a807cd4923a88c38 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -68,12 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { return false; } - // TODO(b/72172404): Great many ops don't support activation function - // fusing. Switch to a categorizing function instead. - if (op->type == OperatorType::kConcatenation || - op->type == OperatorType::kSlice || - op->type == OperatorType::kTensorFlowReshape || - op->type == OperatorType::kTensorFlowSplit) { + if (!OperatorSupportsFusedActivation(op->type)) { AddMessageF( "Not fusing activation function because the %s op doesn't support it", LogName(*op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index e11bebcd4e0f66faf63290e3af0c72c39811cebe..cf90ebe99697fe8a40b4c707e70fdc5318123854 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -144,6 +144,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) +DECLARE_GRAPH_TRANSFORMATION(ReorderActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 9689b205cd137904504d87906cb691d0ed8235bf..f1892136cfe5ea8b82b579226f3c16b279781f49 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -219,6 +219,12 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); break; + case OperatorType::kTanh: + // We hardcode quantization_params to: zero_point=127, scale=1/128. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, -127. / 128., 1.0); + break; + default: break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 4fb3b6ae7a5fc5bfc2719b978331c67ae799eb54..fa7e70d90b421eba7545e458a59c915b5ac183ef 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -546,6 +546,9 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { // Use 0 input as basis for output dimensions. const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); + // Negative axis means the count starts at the back of the dims(). + int axis = op->axis; + if (axis < 0) axis += first_input_array.shape().dims().size(); // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; @@ -558,14 +561,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { CHECK_EQ(input_array.shape().dimensions_count(), output_array.shape().dimensions_count()); const std::vector& input_dims = input_array.shape().dims(); - CHECK_LT(op->axis, input_dims.size()); - concat_size += input_dims[op->axis]; + CHECK_LT(axis, input_dims.size()); + concat_size += input_dims[axis]; } // Write out the concat_size on the output array shape. auto& output_shape = *output_array.mutable_shape(); auto& output_dims = *output_shape.mutable_dims(); - CHECK_LT(op->axis, output_shape.dimensions_count()); - output_dims[op->axis] = concat_size; + CHECK_LT(axis, output_shape.dimensions_count()); + output_dims[axis] = concat_size; } void ProcessRangeOperator(Model* model, RangeOperator* op) { @@ -1120,7 +1123,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { stop += input_array.shape().dims(i); } - int dim_size = (stop - start) / op->strides[i]; + int dim_size = ceil((stop - start) / static_cast(op->strides[i])); + dim_size = dim_size < 0 ? 0 : dim_size; if (op->shrink_axis_mask & mask) { CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 when " "shrinking that axis"; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index b973b2b813147cc580d2e87cea7d395f180f5aa1..139c19022ed45e0243ca7c1a84717a5d79bb81cf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -41,10 +41,14 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kConcatenation || type == OperatorType::kL2Normalization || type == OperatorType::kAdd || type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || + type == OperatorType::kTensorFlowMinimum || + type == OperatorType::kTensorFlowMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || + type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kTensorFlowReshape || - type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || + type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kSpaceToDepth || type == OperatorType::kDepthToSpace; } @@ -258,6 +262,17 @@ bool ChooseHardcodedQuantizationForOperatorOutput( *quantization_params)); return true; } + if (op.type == OperatorType::kTanh) { + // Tanh has the range: [-1, 1]. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 127; + quantization_params->scale = 1. / 128.; + // 0 should be exactly representable, as values will typically be centered + // around 0, with many values near 0. + CHECK( + IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); + return true; + } return false; } @@ -395,11 +410,22 @@ bool Quantize::Run(Model* model, std::size_t op_index) { if (IsConstantParameterArray(*model, input)) { QuantizeArray(this, model, input, quantized_data_type, quantization_params); + } else if (toco::IsRnnStateArray(*model, input)) { + // Simply Quantize the Array + auto& array = model->GetArray(op.inputs[input_index]); + array.GetOrCreateQuantizationParams() = quantization_params; + array.data_type = quantized_data_type; } else { auto dequantize_it = FindOpWithOutput(*model, input); - CHECK(dequantize_it != model->operators.end()); + CHECK(dequantize_it != model->operators.end()) + << "Cannot quantize input \"" << input + << "\" on operator with output \"" << op.outputs[0] + << "\". Nothing feeding input."; auto* dequantize_op = dequantize_it->get(); - CHECK(dequantize_op->type == OperatorType::kDequantize); + CHECK(dequantize_op->type == OperatorType::kDequantize) + << "Cannot quantize input \"" << input + << "\" on operator with output \"" << op.outputs[0] + << "\". Input is not fed by a Dequantize operator."; op.inputs[input_index] = dequantize_op->inputs[0]; // Check if the output of that Dequantize op was not used by any // other operator. We will then erase that Dequantize op. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..cabbc4d313be3069053f056eb0de45c37ba2e7a4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) { + const auto ac_it = model->operators.begin() + op_index; + std::unique_ptr& ac_op = *ac_it; + DCHECK(ac_op); + + if (ac_op->type != OperatorType::kRelu6 && + ac_op->type != OperatorType::kRelu1 && + ac_op->type != OperatorType::kRelu) { + return false; + } + + auto exchange_it = FindOpWithOutput(*model, ac_op->inputs[0]); + if (exchange_it == model->operators.end()) return false; + // Find the op producing the array passed to this activation function + std::unique_ptr& exchange_op = *exchange_it; + DCHECK(exchange_op); + + if (exchange_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + DCHECK_EQ(exchange_op->outputs[0], ac_op->inputs[0]); + const auto& exchange_op_input = exchange_op->inputs[0]; + const auto& intermediate_array = exchange_op->outputs[0]; + const auto& ac_op_output = ac_op->outputs[0]; + + int count_ops_consuming_output = + CountOpsWithInput(*model, intermediate_array); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not exchanging activation function with %s because it is consumed by " + "more than 1 other operator", + LogName(*exchange_op)); + return false; + } + + // Rewire by changing inputs, including all consumers. + Operator* consumer = GetFirstOpWithInput(*model, ac_op_output); + while (consumer) { + for (int i = 0; i < consumer->inputs.size(); ++i) { + if (consumer->inputs[i] == ac_op_output) { + consumer->inputs[i] = intermediate_array; + } + } + consumer = GetFirstOpWithInput(*model, ac_op_output); + } + ac_op->inputs[0] = exchange_op_input; + exchange_op->inputs[0] = ac_op_output; + + // Finally, reorder operators. Note that this only works when there are no + // other direct descendents of the exchange_op. + ac_op.swap(exchange_op); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index 5ac449749adbc9b5422f996eeccb72575dca8722..db68968bad1e6fbf020c6ac82d6871b5a071b29e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -73,7 +73,7 @@ void CopyTensorSegments(const std::vector& input_arrays, // Receives a series of input arrays of type Array and an integer showing the // axis on which those arrays will be concatenated. It returns the concatenated -// arrray. +// array. template void ConcatenateTensorBuffers(const std::vector& input_arrays, int concatenation_axis, diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index ca378af4c5c1e1b8cf42a10d3820db3feeb49a05..c12706e52dd6257d7267bc4fb14ffb0053f3d69c 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -173,7 +173,8 @@ void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { } auto& output_float_data = output_array->GetMutableBuffer().data; - output_float_data.resize(input_flat_size); + output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()), + 0.f); if (input_tensor.float_val_size() == 1) { for (int i = 0; i < input_flat_size; i++) { output_float_data[i] = input_tensor.float_val(0); @@ -203,7 +204,7 @@ void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -229,7 +230,7 @@ void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -255,7 +256,7 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int64_val_size()) { for (int i = 0; i < input_tensor.int64_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); @@ -281,7 +282,7 @@ void ImportStringArray(const TensorProto& input_tensor, Array* output_array) { } auto& output_string_data = output_array->GetMutableBuffer().data; - output_string_data.resize(input_flat_size); + output_string_data.resize(RequiredBufferSizeForShape(output_array->shape())); if (input_flat_size != input_tensor.string_val_size()) { LOG(FATAL) << "Input_content string_val doesn't have the right " "dimensions for this string tensor."; @@ -1311,6 +1312,12 @@ void ConvertResizeBilinearOperator(const NodeDef& node, CHECK_EQ(node.op(), "ResizeBilinear"); CheckInputsCount(node, tf_import_flags, 2); auto* op = new ResizeBilinearOperator; + + op->align_corners = false; + if (HasAttr(node, "align_corners")) { + op->align_corners = GetBoolAttr(node, "align_corners"); + } + op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 6fba8f2629f785ffeb3ae37b80ec1d24c29d9d56..447618ec8519f20ccdf346fc766415b5b829c2b5 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -1273,6 +1273,8 @@ struct ArgMaxOperator : Operator { // TensorFlow equivalent: ResizeBilinear struct ResizeBilinearOperator : Operator { ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {} + + bool align_corners = false; }; // SpaceToBatchND operator. It divides spatial dimensions into a grid of diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.cc b/tensorflow/contrib/lite/toco/tensorflow_util.cc index 82e2800ca2f5bb017f91b5bf43d8d3cd05e97b83..0e7e9c41a066581b14fe1b78f83d8d57b916be6c 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_util.cc +++ b/tensorflow/contrib/lite/toco/tensorflow_util.cc @@ -51,7 +51,8 @@ void LogDumpGraphDef(int log_level, const string& message, BEGIN DUMP OF TENSORFLOW GRAPHDEF (%s) There are %d nodes. There are %zu different op types: -)MSG", message, tf_graph.node_size(), ops.size()); +)MSG", + message, tf_graph.node_size(), ops.size()); for (const auto& op : ops) { toco::port::AppendF(&dump, " %s\n", op); } @@ -63,7 +64,8 @@ PROTO DUMP BEGIN NODE: name = %s op = %s inputs = [ -)MSG", node.name(), node.op()); +)MSG", + node.name(), node.op()); for (const auto& input : node.input()) { toco::port::AppendF(&dump, " %s\n", input); } diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 391ef87029d019ab52af2716f72883f5f82f94d9..27719599708a7eb14f72a82f8e5d76b3b8af9dc4 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -26,6 +26,9 @@ namespace toco { namespace tflite { +using flatbuffers::FlatBufferBuilder; +using flatbuffers::Offset; +using flatbuffers::Vector; using ::tflite::Buffer; using ::tflite::BuiltinOperator; using ::tflite::BuiltinOperator_CUSTOM; @@ -39,9 +42,6 @@ using ::tflite::Operator; using ::tflite::OperatorCode; using ::tflite::SubGraph; using ::tflite::Tensor; -using flatbuffers::FlatBufferBuilder; -using flatbuffers::Offset; -using flatbuffers::Vector; namespace { diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 2d6bccce2b8aa498babf53565e4feb5a261c2058..04aaedd59d76ce63a80e80debba48f2b085d8db9 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -140,25 +140,11 @@ class SpaceToBatchND flatbuffers::Offset WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto block_shape = builder->CreateVector(op.block_shape); - auto before_paddings = builder->CreateVector(op.before_paddings); - auto after_paddings = builder->CreateVector(op.after_paddings); - return ::tflite::CreateSpaceToBatchNDOptions( - *builder, block_shape, before_paddings, after_paddings); + return ::tflite::CreateSpaceToBatchNDOptions(*builder); } void ReadOptions(const TfLiteOptions& options, - TocoOperator* op) const override { - op->block_shape.insert(op->block_shape.end(), - options.block_shape()->begin(), - options.block_shape()->end()); - op->before_paddings.insert(op->before_paddings.end(), - options.before_paddings()->begin(), - options.before_paddings()->end()); - op->after_paddings.insert(op->after_paddings.end(), - options.after_paddings()->begin(), - options.after_paddings()->end()); - } + TocoOperator* op) const override {} }; class Sub : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto block_shape = builder->CreateVector(op.block_shape); - auto before_crops = builder->CreateVector(op.before_crops); - auto after_crops = builder->CreateVector(op.after_crops); - return ::tflite::CreateBatchToSpaceNDOptions(*builder, block_shape, - before_crops, after_crops); + return ::tflite::CreateBatchToSpaceNDOptions(*builder); } void ReadOptions(const TfLiteOptions& options, - TocoOperator* op) const override { - op->block_shape.insert(op->block_shape.end(), - options.block_shape()->begin(), - options.block_shape()->end()); - op->before_crops.insert(op->before_crops.end(), - options.before_crops()->begin(), - options.before_crops()->end()); - op->after_crops.insert(op->after_crops.end(), - options.after_crops()->begin(), - options.after_crops()->end()); - } + TocoOperator* op) const override {} }; class Cast : public CustomOperator { @@ -478,8 +450,7 @@ class Pad : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto axis = builder->CreateVector(op.axis); - return ::tflite::CreateMeanOptions(*builder, axis, op.keep_dims); + return ::tflite::CreateMeanOptions(*builder, op.keep_dims); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->axis.insert(op->axis.end(), options.axis()->begin(), - options.axis()->end()); op->keep_dims = options.keep_dims(); } }; +class ResizeBilinear + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->align_corners = options.align_corners(); + } +}; + class Squeeze : public BuiltinOperator { @@ -788,6 +773,8 @@ std::vector> BuildOperatorList() { OperatorType::kTranspose)); ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, + OperatorType::kResizeBilinear)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, @@ -820,8 +807,6 @@ std::vector> BuildOperatorList() { new SimpleOperator("RELU_N1_TO_1", OperatorType::kRelu1)); ops.emplace_back( new SimpleOperator("RELU6", OperatorType::kRelu6)); - ops.emplace_back(new SimpleOperator( - "RESIZE_BILINEAR", OperatorType::kResizeBilinear)); ops.emplace_back(new SimpleOperator( "LOGISTIC", OperatorType::kLogistic)); ops.emplace_back( diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 78af3a767d33f914cd56a037f48530778ffb616e..796534be53cba0ea772e974cd8173c0b4c12e6c3 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -104,8 +104,6 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("RELU", OperatorType::kRelu); CheckSimpleOperator("RELU_N1_TO_1", OperatorType::kRelu1); CheckSimpleOperator("RELU6", OperatorType::kRelu6); - CheckSimpleOperator("RESIZE_BILINEAR", - OperatorType::kResizeBilinear); CheckSimpleOperator("LOGISTIC", OperatorType::kLogistic); CheckSimpleOperator("TANH", OperatorType::kTanh); } @@ -119,40 +117,12 @@ TEST_F(OperatorTest, BuiltinAdd) { output_toco_op->fused_activation_function); } -TEST_F(OperatorTest, BuiltinSpaceToBatchND) { - SpaceToBatchNDOperator op; - op.block_shape = {2, 2}; - op.before_paddings = {1, 2}; - op.after_paddings = {3, 4}; - - auto output_toco_op = SerializeAndDeserialize( - GetOperator("SPACE_TO_BATCH_ND", OperatorType::kSpaceToBatchND), op); - EXPECT_EQ(op.block_shape, output_toco_op->block_shape); - EXPECT_EQ(op.before_paddings, output_toco_op->before_paddings); - EXPECT_EQ(op.after_paddings, output_toco_op->after_paddings); -} - -TEST_F(OperatorTest, BuiltinBatchToSpaceND) { - BatchToSpaceNDOperator op; - op.block_shape = {2, 2}; - op.before_crops = {1, 2}; - op.after_crops = {3, 4}; - - auto output_toco_op = SerializeAndDeserialize( - GetOperator("BATCH_TO_SPACE_ND", OperatorType::kBatchToSpaceND), op); - EXPECT_EQ(op.block_shape, output_toco_op->block_shape); - EXPECT_EQ(op.before_crops, output_toco_op->before_crops); - EXPECT_EQ(op.after_crops, output_toco_op->after_crops); -} - TEST_F(OperatorTest, BuiltinMean) { MeanOperator op; - op.axis = {1, 2}; op.keep_dims = false; auto output_toco_op = SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op); - EXPECT_EQ(op.axis, output_toco_op->axis); EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); } @@ -359,6 +329,14 @@ TEST_F(OperatorTest, BuiltinMul) { output_toco_op->fused_activation_function); } +TEST_F(OperatorTest, ResizeBilinear) { + ResizeBilinearOperator op; + op.align_corners = true; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op); + EXPECT_EQ(op.align_corners, output_toco_op->align_corners); +} + TEST_F(OperatorTest, Svdf) { SvdfOperator op; op.fused_activation_function = FusedActivationFunctionType::kRelu; diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 727df1cc76ae332682a50db534e6bfa20ffc45ca..b715881774bf00f6cf2a50452a5b5c59c647ade6 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -68,6 +68,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new ReorderActivationFunctions); transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index dbe0280a88a8771a10f524c9dbb347b326f16006..ff8bc471b7f41bb0c9d51b8b51637412c09dadfc 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -304,6 +304,19 @@ string HelpfulOperatorTypeName(const Operator& op) { return OperatorTypeName(op.type); } +bool OperatorSupportsFusedActivation(OperatorType type) { + switch (type) { + case OperatorType::kConcatenation: + case OperatorType::kSlice: + case OperatorType::kSqueeze: + case OperatorType::kTensorFlowReshape: + case OperatorType::kTensorFlowSplit: + return false; + default: + return true; + } +} + void LogSummary(int log_level, const Model& model) { VLOG(log_level) << "Operators summary (" << model.operators.size() << " operators):"; @@ -1747,4 +1760,22 @@ void UseArraysExtraInfo(Model* model) { } } +bool IsRnnSourceArray(const toco::Model& model, const string& array_name) { + for (const auto& rnn_state : model.flags.rnn_states()) { + if (array_name == rnn_state.back_edge_source_array()) { + return true; + } + } + return false; +} + +bool IsRnnStateArray(const toco::Model& model, const string& array_name) { + for (const auto& rnn_state : model.flags.rnn_states()) { + if (array_name == rnn_state.state_array()) { + return true; + } + } + return false; +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 4051ba3576d272e9c7ab39c0cf344dd70f62d0ea..a023bab1a0a0a09f3afe7bcd5afa787c1962874d 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -82,6 +82,8 @@ std::vector>::iterator FindOp(Model& model, const char* OperatorTypeName(OperatorType type); string HelpfulOperatorTypeName(const Operator& op); +bool OperatorSupportsFusedActivation(OperatorType type); + void DumpGraphvizVideoFrame(const Model& model); void LogDump(int log_level, const string& message, const Model& model); void LogSummary(int log_level, const string& message, const Model& model); @@ -283,6 +285,9 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type); void UseArraysExtraInfo(Model* model); +bool IsRnnSourceArray(const toco::Model& model, const string& array_name); +bool IsRnnStateArray(const toco::Model& model, const string& array_name); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 1bffcfb987330c5d067d7f986a486fcf93e57ee7..6786b1618456637aecfd870b9984af65b59784f6 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -99,8 +99,11 @@ cc_library( srcs = ["verifier.cc"], hdrs = ["verifier.h"], deps = [ + "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/schema:schema_fbs", + "@com_google_absl//absl/base:core_headers", ], ) @@ -112,8 +115,10 @@ cc_test( ":verifier", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:framework_lite", "@com_google_googletest//:gtest", "@flatbuffers", ], diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc index 95a0895379845d8887939e0217270b30ea5584ca..726e2aaa3162591593cd2abd6384eb55baf0aef4 100644 --- a/tensorflow/contrib/lite/tools/verifier.cc +++ b/tensorflow/contrib/lite/tools/verifier.cc @@ -14,13 +14,32 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/tools/verifier.h" +#include #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/version.h" namespace tflite { namespace { +// Reports error message when the reporter is set. +void ReportError(ErrorReporter* error_reporter, const char* format, ...) { + if (error_reporter) { + va_list args; + va_start(args, format); + error_reporter->Report(format, args); + va_end(args); + } +} + +// Returns the int32_t value pointed by ptr. +const uint32_t* GetIntPtr(const char* ptr) { + return reinterpret_cast(ptr); +} + +// Verifies flatbuffer format of the model contents and returns the in-memory +// model. const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) { ::flatbuffers::Verifier verifier(static_cast(buf), len); if (VerifyModelBuffer(verifier)) { @@ -30,14 +49,159 @@ const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) { } } +const uint32_t kMaxNumString = UINT_MAX / sizeof(int32_t) - 2; + +// Verifies string tensor has legit buffer contents that follow the schema +// defined in lite/string_util.h +bool VerifyStringTensorBuffer(const Buffer& buffer, + ErrorReporter* error_reporter) { + uint32_t buffer_size = buffer.data()->size(); + const char* buffer_ptr = reinterpret_cast(buffer.data()->data()); + + uint32_t num_strings = *GetIntPtr(buffer_ptr); + if (num_strings > kMaxNumString) { + ReportError(error_reporter, + "String tensor has invalid num of string set: %d", num_strings); + return false; + } + uint32_t header_offsets = + static_cast(num_strings + 2) * sizeof(int32_t); + + if (buffer_size < header_offsets) { + ReportError(error_reporter, + "String tensor buffer requires at least %d bytes, but is " + "allocated with %d bytes", + header_offsets, buffer_size); + return false; + } + + uint32_t prev_ptr = header_offsets; + uint32_t offset = sizeof(int32_t); + + if (*GetIntPtr(buffer_ptr + offset) != header_offsets) { + ReportError(error_reporter, + "String tensor buffer initial offset must be: %d", + header_offsets); + return false; + } + offset += sizeof(int32_t); + for (int i = 1; i <= num_strings; i++, offset += sizeof(int32_t)) { + int string_offset = *GetIntPtr(buffer_ptr + offset); + if (string_offset < prev_ptr || string_offset > buffer_size) { + ReportError(error_reporter, "String tensor buffer is invalid: index %d", + i); + return false; + } + } + if (*GetIntPtr(buffer_ptr + offset - sizeof(int32_t)) != buffer_size) { + ReportError(error_reporter, "String tensor buffer last offset must be %d", + buffer_size); + return false; + } + return true; +} + +// Verifies numeric tensor has legit buffer. +bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, + ErrorReporter* error_reporter) { + uint64_t bytes_required = 1; + for (int dim : *tensor.shape()) { + bytes_required *= dim; + if (bytes_required > UINT_MAX) { + ReportError(error_reporter, "Tensor dimension overflow"); + return false; + } + } + switch (tensor.type()) { + case TensorType_FLOAT32: + bytes_required *= sizeof(float); + break; + case TensorType_INT32: + bytes_required *= sizeof(int32_t); + break; + case TensorType_UINT8: + bytes_required *= sizeof(uint8_t); + break; + case TensorType_INT64: + bytes_required *= sizeof(int64_t); + break; + case TensorType_FLOAT16: + // FALLTHROUGH_INTENDED; + default: + ReportError(error_reporter, "Invalid tensor type: %d", tensor.type()); + return false; + } + if (bytes_required > UINT_MAX) { + ReportError(error_reporter, "Tensor dimension overflow"); + return false; + } + + if (bytes_required != buffer.data()->size()) { + ReportError( + error_reporter, + "Tensor requires %d bytes, but is allocated with %d bytes buffer", + bytes_required, buffer.data()->size()); + return false; + } + return true; + + // TODO(yichengfan): verify quantized tensors. +} + +// Verifies tensors have valid properties and legit buffer if set. +bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { + if (!model.subgraphs()) { + return true; + } + for (const auto& subgraph : *model.subgraphs()) { + if (!subgraph->tensors()) { + return true; + } + for (const auto& tensor : *subgraph->tensors()) { + if (!tensor->buffer()) { + return true; + } + if (tensor->buffer() >= model.buffers()->size()) { + ReportError(error_reporter, "Invalid tensor buffer index: %d", + tensor->buffer()); + return false; + } + auto* buffer = model.buffers()->Get(tensor->buffer()); + if (!buffer || !buffer->data()) { + ReportError(error_reporter, "Tensor buffer %d not set", + tensor->buffer()); + return false; + } + + if (tensor->type() == TensorType_STRING) { + if (!VerifyStringTensorBuffer(*buffer, error_reporter)) { + return false; + } + } else { + if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) { + return false; + } + } + } + } + return true; +} + } // namespace -bool Verify(const void* buf, size_t len) { +bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) { const Model* model = VerifyFlatbufferAndGetModel(buf, len); if (model == nullptr) { + ReportError(error_reporter, "Invalid flatbuffer format"); return false; } - - return model->version() == TFLITE_SCHEMA_VERSION; + if (model->version() != TFLITE_SCHEMA_VERSION) { + ReportError(error_reporter, "Invalid model version %d", model->version()); + return false; + } + if (!VerifyTensors(*model, error_reporter)) { + return false; + } + return true; } } // namespace tflite diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h index 03e1f22b7e87baf6d1586dde5812fc854d9e2c4c..d2bf3c91d54225098c1f254c26971e8bb962f791 100644 --- a/tensorflow/contrib/lite/tools/verifier.h +++ b/tensorflow/contrib/lite/tools/verifier.h @@ -18,13 +18,15 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/error_reporter.h" + namespace tflite { // Verifies the integrity of a Tensorflow Lite flatbuffer model file. // Currently, it verifies: // * The file is following a legit flatbuffer schema. // * The model is in supported version. -bool Verify(const void* buf, size_t len); +bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter); } // namespace tflite diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index 0481a55a78e5e1cd0821df13ffaf84bbe28a1b8e..87f6854e9e67c0389949c8d72a476036051d1c0f 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/verifier.h" +#include +#include + #include "flatbuffers/flatbuffers.h" #include "flatbuffers/util.h" #include @@ -20,7 +22,9 @@ limitations under the License. #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/tools/verifier.h" #include "tensorflow/contrib/lite/version.h" +#include "tensorflow/core/framework/numeric_types.h" namespace tflite { @@ -28,31 +32,62 @@ using flatbuffers::FlatBufferBuilder; using flatbuffers::Offset; using flatbuffers::Vector; -// Class that abstracts the list of buffers at the end of the TF Lite structure -class DeferredBufferWriter { +// Build single subgraph model. +class TfLiteFlatbufferModelBuilder { public: - DeferredBufferWriter() { - data_.push_back({}); // sentinel empty buffer. + TfLiteFlatbufferModelBuilder() { + buffers_.push_back( + CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); } - Offset>> BuildBuffers(FlatBufferBuilder *builder) { - std::vector> buffer_vector; - for (const auto &vec : data_) { - auto data_buffer = builder->CreateVector(vec.data(), vec.size()); - buffer_vector.push_back(tflite::CreateBuffer(*builder, data_buffer)); + void AddTensor(const std::vector& shape, tflite::TensorType type, + const std::vector& buffer, const char* name) { + int buffer_index = 0; + if (!buffer.empty()) { + buffer_index = buffers_.size(); + buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector(buffer))); } - return builder->CreateVector(buffer_vector); + tensors_.push_back(CreateTensorDirect(builder_, &shape, type, buffer_index, + name, /*quantization=*/0)); + } + + void AddOperator(const std::vector& inputs, + const std::vector& outputs, + tflite::BuiltinOperator builtin_op, const char* custom_op) { + operator_codes_.push_back( + CreateOperatorCodeDirect(builder_, builtin_op, custom_op)); + operators_.push_back(CreateOperator( + builder_, operator_codes_.size() - 1, builder_.CreateVector(inputs), + builder_.CreateVector(outputs), BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS)); } - // Registers a buffer index and takes ownership of the data to write to it. - int Record(std::vector data) { - int buffer_index = data_.size(); - data_.emplace_back(std::move(data)); - return buffer_index; + void FinishModel(const std::vector& inputs, + const std::vector& outputs) { + auto subgraph = std::vector>({CreateSubGraph( + builder_, builder_.CreateVector(tensors_), + builder_.CreateVector(inputs), builder_.CreateVector(outputs), + builder_.CreateVector(operators_), + builder_.CreateString("test_subgraph"))}); + auto result = CreateModel( + builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(operator_codes_), + builder_.CreateVector(subgraph), builder_.CreateString("test_model"), + builder_.CreateVector(buffers_)); + tflite::FinishModelBuffer(builder_, result); + } + + bool Verify() { + return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(), + DefaultErrorReporter()); } private: - std::vector> data_; + FlatBufferBuilder builder_; + std::vector> operators_; + std::vector> operator_codes_; + std::vector> tensors_; + std::vector> buffers_; }; TEST(VerifyModel, TestEmptyModel) { @@ -62,43 +97,26 @@ TEST(VerifyModel, TestEmptyModel) { /*description=*/0, /*buffers=*/0); ::tflite::FinishModelBuffer(builder, model); - ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize())); + ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(), + DefaultErrorReporter())); } TEST(VerifyModel, TestSimpleModel) { - FlatBufferBuilder builder; - auto inputs = builder.CreateVector({0}); - auto outputs = builder.CreateVector({1}); - auto operator_codes = builder.CreateVector(std::vector>{ - CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "test")}); - auto operators = - builder.CreateVector(std::vector>{CreateOperator( - builder, /*opcode_index=*/0, - /*inputs=*/builder.CreateVector({0}), - /*outputs=*/builder.CreateVector({1}), BuiltinOptions_NONE, - /*builtin_options=*/0, - /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS)}); - std::vector shape; - auto tensors = builder.CreateVector(std::vector>{ - CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0, - "input", /*quantization=*/0), - CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0, - "output", /*quantization=*/0)}); - auto subgraph = std::vector>( - {CreateSubGraph(builder, tensors, inputs, outputs, operators, - builder.CreateString("Main"))}); - - auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, operator_codes, - builder.CreateVector(subgraph), - builder.CreateString("SmartReply"), /*buffers=*/0); - - ::tflite::FinishModelBuffer(builder, model); - ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize())); + TfLiteFlatbufferModelBuilder builder; + builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "test"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4, 5, 6}, "input"); + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 19, 0, 0, 0, 'A', 'B', 'C'}, + "data"); + builder.AddTensor({2, 3}, TensorType_INT32, {}, "output"); + builder.FinishModel({0, 1}, {2}); + ASSERT_TRUE(builder.Verify()); } TEST(VerifyModel, TestCorruptedData) { - string model = "123"; - ASSERT_FALSE(Verify(model.data(), model.size())); + std::string model = "123"; + ASSERT_FALSE(Verify(model.data(), model.size(), /*error_reporter=*/nullptr)); } TEST(VerifyModel, TestUnsupportedVersion) { @@ -106,7 +124,8 @@ TEST(VerifyModel, TestUnsupportedVersion) { auto model = CreateModel(builder, /*version=*/1, /*operator_codes=*/0, /*subgraphs=*/0, /*description=*/0, /*buffers=*/0); ::tflite::FinishModelBuffer(builder, model); - ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize())); + ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), + DefaultErrorReporter())); } TEST(VerifyModel, TestRandomModificationIsNotAllowed) { @@ -116,20 +135,105 @@ TEST(VerifyModel, TestRandomModificationIsNotAllowed) { /*subgraphs=*/0, /*description=*/0, /*buffers=*/0); ::tflite::FinishModelBuffer(builder, model); - string model_content(reinterpret_cast(builder.GetBufferPointer()), - builder.GetSize()); + std::string model_content(reinterpret_cast(builder.GetBufferPointer()), + builder.GetSize()); for (int i = 0; i < model_content.size(); i++) { model_content[i] = (model_content[i] + 137) % 255; - EXPECT_FALSE(Verify(model_content.data(), model_content.size())) + EXPECT_FALSE(Verify(model_content.data(), model_content.size(), + DefaultErrorReporter())) << "Fail at position: " << i; } } +TEST(VerifyModel, TestIntTensorShapeIsGreaterThanBuffer) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, TestIntTensorShapeIsSmallerThanBuffer) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor({2, 1}, TensorType_UINT8, {1, 2, 3, 4}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, TestIntTensorShapeOverflow) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor({1024, 2048, 4096}, TensorType_UINT8, {1, 2, 3, 4}, + "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, TensorBufferIsNotValid) { + FlatBufferBuilder builder; + std::vector shape = {2, 3}; + auto tensors = builder.CreateVector(std::vector>{ + CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/2, + "input", /*quantization=*/0)}); + auto subgraph = std::vector>( + {CreateSubGraph(builder, tensors, /*inputs=*/0, /*outputs=*/0, + /*operators=*/0, builder.CreateString("Main"))}); + + auto buffers = builder.CreateVector(std::vector>{ + CreateBuffer(builder, + builder.CreateVector(std::vector{1, 2, 3, 4, 5, 6})), + }); + + auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0, + builder.CreateVector(subgraph), + builder.CreateString("SmartReply"), buffers); + + ::tflite::FinishModelBuffer(builder, model); + ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), + DefaultErrorReporter())); +} + +TEST(VerifyModel, StringTensorHasInvalidNumString) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {0x00, 0x00, 0x00, 0x20, 16, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B'}, + "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, StringTensorOffsetTooSmall) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 12, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B'}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, StringTensorOffsetOutOfRange) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 22, 0, 0, 0, 'A', 'B'}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, StringTensorIsLargerThanRequired) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B', 'C'}, + "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + // TODO(yichengfan): make up malicious files to test with. } // namespace tflite -int main(int argc, char **argv) { +int main(int argc, char** argv) { ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py index d0d78e3afab7d89f216bb8ceb42e4429ca4f1759..f571dd59da0a3f4aff264b48fba3e41f75b50404 100644 --- a/tensorflow/contrib/lite/tools/visualize.py +++ b/tensorflow/contrib/lite/tools/visualize.py @@ -198,10 +198,13 @@ class TensorMapper(object): def GenerateGraph(subgraph_idx, g, opcode_mapper): """Produces the HTML required to have a d3 visualization of the dag.""" + def TensorName(idx): - return "t%d"%idx + return "t%d" % idx + def OpName(idx): - return "o%d"%idx + return "o%d" % idx + edges = [] nodes = [] first = {} @@ -210,27 +213,35 @@ def GenerateGraph(subgraph_idx, g, opcode_mapper): for tensor_input_position, tensor_index in enumerate(op["inputs"]): if tensor_index not in first: first[tensor_index] = ( - op_index*pixel_mult, - tensor_input_position*pixel_mult - pixel_mult/2) - edges.append( - {"source": TensorName(tensor_index), "target": OpName(op_index)}) + op_index * pixel_mult, + tensor_input_position * pixel_mult - pixel_mult / 2) + edges.append({ + "source": TensorName(tensor_index), + "target": OpName(op_index) + }) for tensor_index in op["outputs"]: - edges.append( - {"target": TensorName(tensor_index), "source": OpName(op_index)}) - nodes.append({"id": OpName(op_index), - "name": opcode_mapper(op["opcode_index"]), - "group": 2, - "x": pixel_mult, - "y": op_index * pixel_mult}) + edges.append({ + "target": TensorName(tensor_index), + "source": OpName(op_index) + }) + nodes.append({ + "id": OpName(op_index), + "name": opcode_mapper(op["opcode_index"]), + "group": 2, + "x": pixel_mult, + "y": op_index * pixel_mult + }) for tensor_index, tensor in enumerate(g["tensors"]): - initial_y = (first[tensor_index] if tensor_index in first - else len(g["operators"])) - - nodes.append({"id": TensorName(tensor_index), - "name": "%s (%d)" % (tensor["name"], tensor_index), - "group": 1, - "x": 2, - "y": initial_y}) + initial_y = ( + first[tensor_index] if tensor_index in first else len(g["operators"])) + + nodes.append({ + "id": TensorName(tensor_index), + "name": "%s (%d)" % (tensor["name"], tensor_index), + "group": 1, + "x": 2, + "y": initial_y + }) graph_str = json.dumps({"nodes": nodes, "edges": edges}) html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx) @@ -267,7 +278,7 @@ def GenerateTableHtml(items, keys_to_print, display_index=True): for h, mapper in keys_to_print: val = tensor[h] if h in tensor else None val = val if mapper is None else mapper(val) - html += "%s\n"%val + html += "%s\n" % val html += "\n" html += "\n" @@ -279,18 +290,19 @@ def CreateHtmlFile(tflite_input, html_output): # Convert the model into a JSON flatbuffer using flatc (build if doesn't # exist. - if not os.path.exists(tflite_input): + if not os.path.exists(tflite_input): raise RuntimeError("Invalid filename %r" % tflite_input) if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"): # Run convert - cmd = (_BINARY + " -t " - "--strict-json --defaults-json -o /tmp {schema} -- {input}".format( - input=tflite_input, schema=_SCHEMA)) + cmd = ( + _BINARY + " -t " + "--strict-json --defaults-json -o /tmp {schema} -- {input}".format( + input=tflite_input, schema=_SCHEMA)) print(cmd) os.system(cmd) - real_output = ("/tmp/"+ os.path.splitext(os.path.split(tflite_input)[-1])[0] - + ".json") + real_output = ("/tmp/" + os.path.splitext( + os.path.split(tflite_input)[-1])[0] + ".json") data = json.load(open(real_output)) elif tflite_input.endswith(".json"): @@ -302,12 +314,13 @@ def CreateHtmlFile(tflite_input, html_output): html += "

TensorFlow Lite Model

" data["filename"] = tflite_input # Avoid special case - toplevel_stuff = [("filename", None), ("version", None), - ("description", None)] + toplevel_stuff = [("filename", None), ("version", None), ("description", + None)] html += "\n" for key, mapping in toplevel_stuff: - if not mapping: mapping = lambda x: x + if not mapping: + mapping = lambda x: x html += "\n" % (key, mapping(data[key])) html += "
%s%s
\n" @@ -320,22 +333,22 @@ def CreateHtmlFile(tflite_input, html_output): html += "
" tensor_mapper = TensorMapper(g) opcode_mapper = OpCodeMapper(data) - op_keys_to_display = [ - ("inputs", tensor_mapper), ("outputs", tensor_mapper), - ("builtin_options", None), ("opcode_index", opcode_mapper)] - tensor_keys_to_display = [ - ("name", None), ("type", None), ("shape", None), ("buffer", None), - ("quantization", None)] + op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper), + ("builtin_options", None), ("opcode_index", + opcode_mapper)] + tensor_keys_to_display = [("name", None), ("type", None), ("shape", None), + ("buffer", None), ("quantization", None)] html += "

Subgraph %d

\n" % subgraph_idx # Inputs and outputs. html += "

Inputs/Outputs

\n" - html += GenerateTableHtml([{"inputs": g["inputs"], - "outputs": g["outputs"]}], - [("inputs", tensor_mapper), - ("outputs", tensor_mapper)], - display_index=False) + html += GenerateTableHtml( + [{ + "inputs": g["inputs"], + "outputs": g["outputs"] + }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)], + display_index=False) # Print the tensors. html += "

Tensors

\n" @@ -357,8 +370,7 @@ def CreateHtmlFile(tflite_input, html_output): # Operator codes html += "

Operator Codes

\n" - html += GenerateTableHtml(data["operator_codes"], - operator_keys_to_display) + html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display) html += "\n" @@ -370,10 +382,10 @@ def main(argv): tflite_input = argv[1] html_output = argv[2] except IndexError: - print ("Usage: %s " % (argv[0])) + print("Usage: %s " % (argv[0])) else: CreateHtmlFile(tflite_input, html_output) + if __name__ == "__main__": main(sys.argv) - diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index 9d0f95e6f3e7fa9666a99e31578b38d52e0b6b4a..1417772e0496cb571488e5b30bd4f3fb1b591730 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -274,6 +275,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) +@test_util.with_c_api class SparseSoftmaxCrossEntropyLossTest(test.TestCase): def testNoneWeightRaisesValueError(self): @@ -471,7 +473,11 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): labels = constant_op.constant([[0, 1], [2, 3]]) weights = constant_op.constant([1.2, 3.4, 5.6, 7.8]) - with self.assertRaises(errors_impl.InvalidArgumentError): + if ops._USE_C_API: + error_type = ValueError + else: + error_type = errors_impl.InvalidArgumentError + with self.assertRaises(error_type): loss_ops.sparse_softmax_cross_entropy( logits, labels, weights=weights).eval() diff --git a/tensorflow/contrib/makefile/BUILD b/tensorflow/contrib/makefile/BUILD index a8dd59f32a7f3b27993a7ee48ee7cc07ada59a4c..701eeb44fe3f814cb3fb1cedd8618753946cc3e5 100644 --- a/tensorflow/contrib/makefile/BUILD +++ b/tensorflow/contrib/makefile/BUILD @@ -12,20 +12,3 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) - -sh_test( - name = "build_all_linux", - size = "enormous", - srcs = ["build_all_linux.sh"], - data = [ - "//tensorflow:all_opensource_files", - "//third_party/eigen3:all_files", - "//third_party/fft2d:all_files", - ], - tags = [ - "manual", - "no_gpu", - "no_oss", - "notap", - ], -) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index c573cf15da6aa756bf6840206af1663769d0181d..81327407d44b4317b7aecb964a689a35aa35c163 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -407,7 +407,7 @@ $(MARCH_OPTION) \ -I$(JETPACK)/cuda/extras/CUPTI/include - LIBS += \ + CUDA_LIBS := \ -ltfcuda \ -lcudart_static \ -lcudnn \ @@ -420,10 +420,10 @@ $(MARCH_OPTION) \ -lculibos \ -lcurand_static - OBJDIR := $(OBJDIR)Tegra/ - LIBDIR := $(LIBDIR)Tegra/ - BINDIR := $(BINDIR)Tegra/ - DEPDIR := $(DEPDIR)Tegra/ + OBJDIR := $(OBJDIR)android_arm64-v8a/ + LIBDIR := $(LIBDIR)android_arm64-v8a/ + BINDIR := $(BINDIR)android_arm64-v8a/ + DEPDIR := $(DEPDIR)android_arm64-v8a/ TEGRA_LIBS := \ -L$(JETPACK)/cuda/targets/aarch64-linux-androideabi/lib \ @@ -729,7 +729,7 @@ $(BENCHMARK_NAME): $(BENCHMARK_OBJS) $(LIB_PATH) $(CUDA_LIB_DEPS) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ -o $(BENCHMARK_NAME) $(BENCHMARK_OBJS) \ - $(LIBFLAGS) $(TEGRA_LIBS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + $(LIBFLAGS) $(TEGRA_LIBS) $(LIB_PATH) $(LDFLAGS) $(LIBS) $(CUDA_LIBS) # NVCC compilation rules for Tegra ifeq ($(BUILD_FOR_TEGRA),1) diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index 281c4653c627661ae39592e2ea982d04104c30dd..f67c5161861d20ba15ba165491365cdfd0239047 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -37,7 +37,7 @@ fi ARCH=armeabi-v7a -while getopts "Es:t:Tx:a" opt_name; do +while getopts "Es:t:Tx:a:" opt_name; do case "$opt_name" in E) ENABLE_EXPERIMENTAL_HEXNN_OPS="true";; s) SUB_MAKEFILES="${OPTARG}";; diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index a18df256f976c3c0ac4cefe1c884d951e63ef823..2d9979183975e6a17527b40ef5ee1795ced44a7b 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -96,7 +96,7 @@ if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then if [[ -z "${BUILD_ARCH}" ]]; then # Compile protobuf for the target iOS device architectures. - tensorflow/contrib/makefile/compile_ios_protobuf.sh -a ${DEFAULT_ARCH} + tensorflow/contrib/makefile/compile_ios_protobuf.sh else # Compile protobuf for the target iOS device architectures. tensorflow/contrib/makefile/compile_ios_protobuf.sh -a ${BUILD_ARCH} diff --git a/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in b/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in index d9277ed60cb456208572ca1ad8df530648faef82..3081084ee76e41de801f49a67c1fec07f4ff03b9 100644 --- a/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in +++ b/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in @@ -54,7 +54,7 @@ $(INFERENCE_SO_PATH): $(LIB_OBJS) $(INFERENCE_OBJS) $(CUDA_LIB_DEPS) -o $@ $(INFERENCE_OBJS) $(LIB_OBJS) $(TEGRA_LIBS) \ $(LIBFLAGS) $(LDFLAGS) \ -shared -Wl,-soname,$(INFERENCE_SO_NAME) \ - $(LIBS) + $(LIBS) $(CUDA_LIBS) $(INFERENCE_SO_NAME): $(INFERENCE_SO_PATH) diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 9a1ab503178f8c5fb4ba77917bd2cdad0ac72cdf..5a812af4e95fe7a05b9c2634b0cc1d860fb7f619 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -293,3 +293,4 @@ tensorflow/core/kernels/batchtospace_op.cc tensorflow/core/kernels/warn_about_ints.cc tensorflow/core/kernels/segment_reduction_ops.cc tensorflow/core/kernels/batch_util.cc +tensorflow/core/ops/audio_ops.cc diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index 39c0d5af45b4a81fa4dde0b5deac14a3af372cbb..974fb537499c5ea4591a0a128f53d2dea67b9e57 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -80,9 +80,9 @@ REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_GPU).HostMemory("out"), BytesLimitOp); #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_SYCL).HostMemory("out"), - BytesLimitOp); -#endif // TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("BytesLimit").Device(DEVICE_SYCL).HostMemory("out"), BytesLimitOp); +#endif // TENSORFLOW_USE_SYCL // Op that measures the peak memory in bytes. class MaxBytesInUseOp : public MemoryStatsOp { @@ -107,6 +107,6 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("MaxBytesInUse").Device(DEVICE_SYCL).HostMemory("out"), MaxBytesInUseOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py index 2932ae1c8df32cd936cff932b061571c513fda79..ff88b4fa841673fc52b9f6fdc5ca43d30c44bbfd 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py @@ -171,7 +171,14 @@ def _clean_save_and_restore(graph_def, op, removed_op_names): shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes) op.attr['dtypes'].list.type[:] = dtypes + if not name_op.attr['_output_shapes'].list.shape: + name_op.attr['_output_shapes'].list.shape.add() + name_op.attr['_output_shapes'].list.shape[0].dim.add() name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names) + + if not shape_op.attr['_output_shapes'].list.shape: + shape_op.attr['_output_shapes'].list.shape.add() + shape_op.attr['_output_shapes'].list.shape[0].dim.add() shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 55946c128b1a46b8368aedd9f857c1902c4c4586..c2340d037761aa5ebd0dbea128cadb1b65f1a788 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -739,7 +739,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, else: for include in includes: if include not in all_includes: - raise ValueError('Invaild key: %s.' % include) + raise ValueError('Invalid key: %s.' % include) predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py index d07fece4bc668612d517e8dcaab1a35451a0238e..6a3b535eb447dd80f8e39d1d005f8f1d4f503549 100644 --- a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py @@ -58,6 +58,7 @@ def read_cifar10(filename_queue): class CIFAR10Record(object): pass + result = CIFAR10Record() # Dimensions of the images in the CIFAR-10 dataset. @@ -147,8 +148,9 @@ def distorted_inputs(data_dir, batch_size): images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. labels: Labels. 1D tensor of [batch_size] size. """ - filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] + filenames = [ + os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6) + ] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) @@ -174,10 +176,9 @@ def distorted_inputs(data_dir, batch_size): # Because these operations are not commutative, consider randomizing # the order their operation. - distorted_image = tf.image.random_brightness(distorted_image, - max_delta=63) - distorted_image = tf.image.random_contrast(distorted_image, - lower=0.2, upper=1.8) + distorted_image = tf.image.random_brightness(distorted_image, max_delta=63) + distorted_image = tf.image.random_contrast( + distorted_image, lower=0.2, upper=1.8) # Subtract off the mean and divide by the variance of the pixels. float_image = tf.image.per_image_standardization(distorted_image) @@ -188,15 +189,18 @@ def distorted_inputs(data_dir, batch_size): # Ensure that the random shuffling has good mixing properties. min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * - min_fraction_of_examples_in_queue) - print ('Filling queue with %d CIFAR images before starting to train. ' - 'This will take a few minutes.' % min_queue_examples) + min_queue_examples = int( + NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue) + print('Filling queue with %d CIFAR images before starting to train. ' + 'This will take a few minutes.' % min_queue_examples) # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples, batch_size, - shuffle=True) + return _generate_image_and_label_batch( + float_image, + read_input.label, + min_queue_examples, + batch_size, + shuffle=True) def inputs(eval_data, data_dir, batch_size): @@ -212,8 +216,9 @@ def inputs(eval_data, data_dir, batch_size): labels: Labels. 1D tensor of [batch_size] size. """ if not eval_data: - filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] + filenames = [ + os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6) + ] num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN else: filenames = [os.path.join(data_dir, 'test_batch.bin')] @@ -235,8 +240,8 @@ def inputs(eval_data, data_dir, batch_size): # Image processing for evaluation. # Crop the central [height, width] of the image. - resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, - width, height) + resized_image = tf.image.resize_image_with_crop_or_pad( + reshaped_image, width, height) # Subtract off the mean and divide by the variance of the pixels. float_image = tf.image.per_image_standardization(resized_image) @@ -247,10 +252,13 @@ def inputs(eval_data, data_dir, batch_size): # Ensure that the random shuffling has good mixing properties. min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(num_examples_per_epoch * - min_fraction_of_examples_in_queue) + min_queue_examples = int( + num_examples_per_epoch * min_fraction_of_examples_in_queue) # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples, batch_size, - shuffle=False) + return _generate_image_and_label_batch( + float_image, + read_input.label, + min_queue_examples, + batch_size, + shuffle=False) diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc index 0252bc79922fc33d5a90590f3f1ebef4d47a27df..6a7f5efecdb4062874a09df227d139ad20d59f3f 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc @@ -24,11 +24,11 @@ limitations under the License. #include #include -#include "tensorflow/core/distributed_runtime/tensor_coding.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/tensor_coding.h" namespace tensorflow { @@ -62,7 +62,6 @@ BaseRemoteRendezvous* MPIRendezvousMgr::Create(int64 step_id, void MPIRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { - Status s = Status::OK(); MPIRequestTensorCall* rendezvous_call = new MPIRequestTensorCall(); @@ -103,37 +102,37 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync( // Create the function which is called when the Tensor is send by remote const int64 temp1 = step_id_; rendezvous_call->recv_call_ = - [this, parsed, recv_args, done, dst, temp1, rendezvous_call]( - MPIRecvTensorResponse mpi_response) { - Status s; - Device* dst_device; - if (s.ok()) { - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); - CHECK(s.ok()) << "Device lookup failed"; - } - - VLOG(3) << "MPI Received tensor " << parsed.FullKey() - << " @ step: " << temp1 - << " single-send: " << mpi_response.singlesend(); - - Tensor val; - if (mpi_response.singlesend()) { - dst_device->MakeTensorFromProto(mpi_response.response().tensor(), - recv_args.alloc_attrs, &val); - } else { - TensorResponse tr; - tr.InitAlloc(dst_device, recv_args.alloc_attrs); - tr.InitPartial(mpi_response.response()); - const size_t nBytes = tr.tensor().TotalBytes(); - void* data = const_cast(DMAHelper::base(&tr.tensor())); - MPI_Status status; - MPI_CHECK(MPI_Recv(data, static_cast(nBytes), MPI_BYTE, dst, - TAG_SENDTENSOR2, MPI_COMM_WORLD, &status)); - val = std::move(tr.tensor()); - } - - done(s, Args(), recv_args, val, mpi_response.response().is_dead()); - }; + [this, parsed, recv_args, done, dst, temp1, + rendezvous_call](MPIRecvTensorResponse mpi_response) { + Status s; + Device* dst_device; + if (s.ok()) { + s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + CHECK(s.ok()) << "Device lookup failed"; + } + + VLOG(3) << "MPI Received tensor " << parsed.FullKey() + << " @ step: " << temp1 + << " single-send: " << mpi_response.singlesend(); + + Tensor val; + if (mpi_response.singlesend()) { + dst_device->MakeTensorFromProto(mpi_response.response().tensor(), + recv_args.alloc_attrs, &val); + } else { + TensorResponse tr; + tr.InitAlloc(dst_device, recv_args.alloc_attrs); + tr.InitPartial(mpi_response.response()); + const size_t nBytes = tr.tensor().TotalBytes(); + void* data = const_cast(DMAHelper::base(&tr.tensor())); + MPI_Status status; + MPI_CHECK(MPI_Recv(data, static_cast(nBytes), MPI_BYTE, dst, + TAG_SENDTENSOR2, MPI_COMM_WORLD, &status)); + val = std::move(tr.tensor()); + } + + done(s, Args(), recv_args, val, mpi_response.response().is_dead()); + }; MPIRendezvousMgr* mgr = reinterpret_cast(this->rendezvous_mgr_); @@ -159,9 +158,11 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request, TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed)); MPIRecvTensorCallBack send_cb = [this, mpi_dst, parsed]( - const Status& status, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead, - MPISendTensorCall* mpi_send_call) { + const Status& status, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& val, bool is_dead, + MPISendTensorCall* mpi_send_call) { // TODO(jbedorf) this should be a loop over max size CHECK(mpi_send_call->mRes_.ByteSize() < INT_MAX) << "Buffer too large for single transfer"; @@ -194,74 +195,78 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request, }; // Wrapper around the read callback to place the callback on our queue - Rendezvous::DoneCallback done_cb = [this, parsed, step_id, send_cb]( - const Status& status, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { - if (!status.ok()) { - CHECK(status.ok()) << "RecvLocalAsync was not ok, key: " - << parsed.FullKey() << " step: " << step_id - << " error message: " << status.error_message(); - return; - } - - VLOG(3) << "MPI Sending tensor " << parsed.FullKey() - << " @ step: " << step_id << std::endl; - - auto mpi_send_call = new MPISendTensorCall(); - mpi_send_call->Init(parsed, step_id, is_dead); - - Device* src_dev = nullptr; - Status s = this->worker_env_2->device_mgr->LookupDevice(parsed.src_device, - &src_dev); - CHECK(s.ok()) << "src device not found"; - - // Control if shape and data should be send together or if we can optimize - // it in two different transfers, thereby reducing memory copies - bool doOptimalTransfer = true; - if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false; - if (val.TotalBytes() < 1024) doOptimalTransfer = false; - - doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_; - - if (doOptimalTransfer) { - // First send the Tensor description and in a follow up transfer the data - mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype( - val.dtype()); - val.shape().AsProto(mpi_send_call->mRes_.mutable_response() - ->mutable_tensor() - ->mutable_tensor_shape()); - mpi_send_call->mRes_.set_singlesend(false); - } else { - // Send the Tensor description and data in a single transfer - if (src_dev->tensorflow_gpu_device_info() && - (!send_args.alloc_attrs.on_host())) { - Notification n; - GPUUtil::SetProtoFromGPU( - val, src_dev, send_args.device_context, - mpi_send_call->mRes_.mutable_response()->mutable_tensor(), is_dead, - [&n, &s](const Status& s_) { - s = s_; - n.Notify(); - }); - n.WaitForNotification(); - } else { - val.AsProtoTensorContent( - mpi_send_call->mRes_.mutable_response()->mutable_tensor()); - } - } - - std::function res = std::bind( - send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call); - - SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res)); - - this->QueueSendRequest(req); - - // Wait for the notification that indicates the tensor has been - // successfully transmitted to the remote process. Only needed if we - // have not parsed the tensor to proto - if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification(); - }; // done_cb + Rendezvous::DoneCallback done_cb = + [this, parsed, step_id, send_cb]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { + if (!status.ok()) { + CHECK(status.ok()) + << "RecvLocalAsync was not ok, key: " << parsed.FullKey() + << " step: " << step_id + << " error message: " << status.error_message(); + return; + } + + VLOG(3) << "MPI Sending tensor " << parsed.FullKey() + << " @ step: " << step_id << std::endl; + + auto mpi_send_call = new MPISendTensorCall(); + mpi_send_call->Init(parsed, step_id, is_dead); + + Device* src_dev = nullptr; + Status s = this->worker_env_2->device_mgr->LookupDevice( + parsed.src_device, &src_dev); + CHECK(s.ok()) << "src device not found"; + + // Control if shape and data should be send together or if we can + // optimize it in two different transfers, thereby reducing memory + // copies + bool doOptimalTransfer = true; + if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false; + if (val.TotalBytes() < 1024) doOptimalTransfer = false; + + doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_; + + if (doOptimalTransfer) { + // First send the Tensor description and in a follow up transfer the + // data + mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype( + val.dtype()); + val.shape().AsProto(mpi_send_call->mRes_.mutable_response() + ->mutable_tensor() + ->mutable_tensor_shape()); + mpi_send_call->mRes_.set_singlesend(false); + } else { + // Send the Tensor description and data in a single transfer + if (src_dev->tensorflow_gpu_device_info() && + (!send_args.alloc_attrs.on_host())) { + Notification n; + GPUUtil::SetProtoFromGPU( + val, src_dev, send_args.device_context, + mpi_send_call->mRes_.mutable_response()->mutable_tensor(), + is_dead, [&n, &s](const Status& s_) { + s = s_; + n.Notify(); + }); + n.WaitForNotification(); + } else { + val.AsProtoTensorContent( + mpi_send_call->mRes_.mutable_response()->mutable_tensor()); + } + } + + std::function res = std::bind( + send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call); + + SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res)); + + this->QueueSendRequest(req); + + // Wait for the notification that indicates the tensor has been + // successfully transmitted to the remote process. Only needed if we + // have not parsed the tensor to proto + if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification(); + }; // done_cb worker_env_2->compute_pool->Schedule([this, step_id, parsed, done_cb]() { this->RecvLocalAsync(step_id, parsed, done_cb); @@ -293,9 +298,8 @@ void MPIRendezvousMgr::MPIBackgroundThread() { } // Remove sends that have been completed - active_sends.remove_if([](std::unique_ptr& i) { - return i->IsFinished(); - }); + active_sends.remove_if( + [](std::unique_ptr& i) { return i->IsFinished(); }); // send a Tensor request RequestQueueEntry req; diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h index d35e65363f5f031cd3f784e793f3a3d98f61abc7..5596601ddb9846c0e4f5be4bf33114fc19c0a59d 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h @@ -18,12 +18,12 @@ limitations under the License. #ifdef TENSORFLOW_USE_MPI -#include -#include #include -#include -#include #include +#include +#include +#include +#include #include #include #include @@ -161,7 +161,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr { private: typedef std::function MPIRecvTensorCallBack; + const Tensor&, const bool, MPISendTensorCall*)> + MPIRecvTensorCallBack; typedef std::pair> RequestQueueEntry; typedef std::pair> diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc index d585c0565eb234655e7a1bbc92df5741e18c8f33..a31fa9ce0b3110d875689d74a41ca9f9cc85f532 100644 --- a/tensorflow/contrib/mpi/mpi_server_lib.cc +++ b/tensorflow/contrib/mpi/mpi_server_lib.cc @@ -22,8 +22,8 @@ limitations under the License. #include "grpc/support/alloc.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/contrib/mpi/mpi_utils.h b/tensorflow/contrib/mpi/mpi_utils.h index 45e21f2b25ab4897641ffec776eb1b3c32ab9a2e..fa297c28cb47d43ba927ab941854bd472d90b465 100644 --- a/tensorflow/contrib/mpi/mpi_utils.h +++ b/tensorflow/contrib/mpi/mpi_utils.h @@ -18,8 +18,8 @@ limitations under the License. #ifdef TENSORFLOW_USE_MPI -#include #include +#include #include #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc index 2d5b98022c3aafb627e986a2764ee60184014945..8dca90a1e34d6a234c2b1479ca5594e88afcc194 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc @@ -35,8 +35,8 @@ limitations under the License. #define OMPI_SKIP_MPICXX #include "third_party/mpi/mpi.h" -#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" #include "tensorflow/contrib/mpi_collectives/kernels/ring.h" +#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" /* * MPI Allreduce and Allgather Ops for TensorFlow. diff --git a/tensorflow/contrib/ndlstm/python/lstm1d.py b/tensorflow/contrib/ndlstm/python/lstm1d.py index b24e332e4aea7f0ef981909558dcd6d730ca08a7..2e2e9086c00b3e7766678b5eb6dca47dc9a5ddcc 100644 --- a/tensorflow/contrib/ndlstm/python/lstm1d.py +++ b/tensorflow/contrib/ndlstm/python/lstm1d.py @@ -88,7 +88,7 @@ def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False): if reverse: inputs = array_ops.reverse_v2(inputs, [0]) outputs, _ = rnn.dynamic_rnn( - lstm_cell, inputs, time_major=True, dtype=inputs.dtype) + lstm_cell, inputs, time_major=True, dtype=inputs.dtype) if reverse: outputs = array_ops.reverse_v2(outputs, [0]) return outputs diff --git a/tensorflow/contrib/nearest_neighbor/kernels/heap.h b/tensorflow/contrib/nearest_neighbor/kernels/heap.h index 32925569a82c43be75a0b6e93d7d781cda3d53f4..a2dbb8052bfa1634d27c8b38a9bb6ca27fae42a2 100644 --- a/tensorflow/contrib/nearest_neighbor/kernels/heap.h +++ b/tensorflow/contrib/nearest_neighbor/kernels/heap.h @@ -56,7 +56,7 @@ class HeapBase { // This method adds an element at the end of the internal array without // "heapifying" the array afterwards. This is useful for setting up a heap - // where a single call to heapify at the end of the inital insertion + // where a single call to heapify at the end of the initial insertion // operations suffices. void InsertUnsorted(const KeyType& key, const DataType& data) { if (v_.size() == static_cast(num_elements_)) { diff --git a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc index 2b412fac9a621f01bd21c6b4391da3c462dd78b3..13db6f62f525b6318687e3bf4b6499eee2c61ea8 100644 --- a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc +++ b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc @@ -75,7 +75,8 @@ class HyperplaneLSHProbesOp : public OpKernel { num_hyperplanes_per_table, ".")); OP_REQUIRES(context, num_hyperplanes_per_table <= 30, InvalidArgument("Need num_hyperplanes_per_table <= 30, got ", - num_hyperplanes_per_table, ". " + num_hyperplanes_per_table, + ". " "If you need more hyperplanes, change this Op" " to work for larger integer types (int64).")); @@ -88,12 +89,13 @@ class HyperplaneLSHProbesOp : public OpKernel { InvalidArgument("num_probes must be at least 1.")); int expected_num_hyperplanes = num_tables * num_hyperplanes_per_table; - OP_REQUIRES( - context, products_tensor.dim_size(1) == expected_num_hyperplanes, - InvalidArgument("Expected number of hyperplanes is ", - expected_num_hyperplanes, " but received ", - products_tensor.dim_size(1), " inner products per " - "point.")); + OP_REQUIRES(context, + products_tensor.dim_size(1) == expected_num_hyperplanes, + InvalidArgument("Expected number of hyperplanes is ", + expected_num_hyperplanes, " but received ", + products_tensor.dim_size(1), + " inner products per " + "point.")); auto products_eigen_tensor = products_tensor.matrix(); ConstMatrixMap products_matrix(products_eigen_tensor.data(), @@ -116,13 +118,11 @@ class HyperplaneLSHProbesOp : public OpKernel { // lschmidt's workstation. int64 cost_per_unit = 21 * num_hyperplanes_per_table * num_tables; if (num_probes > num_tables) { - cost_per_unit += 110 * num_hyperplanes_per_table - * (num_probes - num_tables); + cost_per_unit += + 110 * num_hyperplanes_per_table * (num_probes - num_tables); } context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( - batch_size, - cost_per_unit, - [&](int64 start, int64 end) { + batch_size, cost_per_unit, [&](int64 start, int64 end) { HyperplaneMultiprobe multiprobe( num_hyperplanes_per_table, num_tables); diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 716ee9cdf704a14a6e433c7f92ccb91739f70655..5763593b81497f5d6945ff1e5d000042d295c093 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -150,7 +150,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): self._global_map = ea_custom_getter._global_map if moving_rate is None: - self._moving_rate = BETA / communication_period / num_worker + self._moving_rate = self.BETA / communication_period / num_worker else: self._moving_rate = moving_rate if rho is None: diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc index 9cee405cef25f54fd064f8002265c42016c4fa50..e18923c8aae74c66ce78f98eb5e615e99463af74 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc @@ -14,13 +14,12 @@ // limitations under the License. // ============================================================================= -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { -REGISTER_KERNEL_BUILDER(Name("PeriodicResample") - .Device(DEVICE_CPU), +REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU), PeriodicResampleOp); } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h index ba410f025d497178cfc1666ae231e75bad55b05e..3ab588c45881c8f93b4c1bcdf7ccde39086a1ed7 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -118,9 +118,9 @@ template #include -#include #include #include #include #include #include #include -#include -#include #include +#include +#include +#include #include #include "tensorflow/core/framework/graph.pb.h" @@ -46,10 +46,10 @@ limitations under the License. // These are all common classes it's handy to reference with no namespace. using tensorflow::Flag; -using tensorflow::Tensor; +using tensorflow::int32; using tensorflow::Status; using tensorflow::string; -using tensorflow::int32; +using tensorflow::Tensor; // Used to store the memory-mapped buffers we use for capture. struct CameraBuffer { diff --git a/tensorflow/contrib/pi_examples/label_image/label_image.cc b/tensorflow/contrib/pi_examples/label_image/label_image.cc index 0b18045789f3a87ceb228033407d6b696bdb33f6..c6935a093f728353caeeb79a9ed85c957d87f066 100644 --- a/tensorflow/contrib/pi_examples/label_image/label_image.cc +++ b/tensorflow/contrib/pi_examples/label_image/label_image.cc @@ -23,9 +23,9 @@ limitations under the License. // // Full build instructions are at tensorflow/contrib/pi_examples/README.md. -#include #include #include +#include #include #include @@ -46,10 +46,10 @@ limitations under the License. // These are all common classes it's handy to reference with no namespace. using tensorflow::Flag; -using tensorflow::Tensor; +using tensorflow::int32; using tensorflow::Status; using tensorflow::string; -using tensorflow::int32; +using tensorflow::Tensor; // Takes a file name, and loads a list of labels from it, one per line, and // returns a vector of the strings. It pads with empty strings so the length @@ -77,23 +77,22 @@ Status ReadLabelsFile(string file_name, std::vector* result, // Error handling for JPEG decoding. void CatchError(j_common_ptr cinfo) { (*cinfo->err->output_message)(cinfo); - jmp_buf *jpeg_jmpbuf = reinterpret_cast(cinfo->client_data); + jmp_buf* jpeg_jmpbuf = reinterpret_cast(cinfo->client_data); jpeg_destroy(cinfo); longjmp(*jpeg_jmpbuf, 1); } // Decompresses a JPEG file from disk. Status LoadJpegFile(string file_name, std::vector* data, - int* width, int* height, int* channels) { + int* width, int* height, int* channels) { struct jpeg_decompress_struct cinfo; - FILE * infile; + FILE* infile; JSAMPARRAY buffer; int row_stride; if ((infile = fopen(file_name.c_str(), "rb")) == NULL) { LOG(ERROR) << "Can't open " << file_name; - return tensorflow::errors::NotFound("JPEG file ", file_name, - " not found"); + return tensorflow::errors::NotFound("JPEG file ", file_name, " not found"); } struct jpeg_error_mgr jerr; @@ -116,10 +115,11 @@ Status LoadJpegFile(string file_name, std::vector* data, data->resize((*height) * (*width) * (*channels)); row_stride = cinfo.output_width * cinfo.output_components; - buffer = (*cinfo.mem->alloc_sarray) - ((j_common_ptr) &cinfo, JPOOL_IMAGE, row_stride, 1); + buffer = (*cinfo.mem->alloc_sarray)((j_common_ptr)&cinfo, JPOOL_IMAGE, + row_stride, 1); while (cinfo.output_scanline < cinfo.output_height) { - tensorflow::uint8* row_address = &((*data)[cinfo.output_scanline * row_stride]); + tensorflow::uint8* row_address = + &((*data)[cinfo.output_scanline * row_stride]); jpeg_read_scanlines(&cinfo, buffer, 1); memcpy(row_address, buffer[0], row_stride); } @@ -141,24 +141,25 @@ Status ReadTensorFromImageFile(string file_name, const int wanted_height, int image_height; int image_channels; TF_RETURN_IF_ERROR(LoadJpegFile(file_name, &image_data, &image_width, - &image_height, &image_channels)); - LOG(INFO) << "Loaded JPEG: " << image_width << "x" << image_height - << "x" << image_channels; + &image_height, &image_channels)); + LOG(INFO) << "Loaded JPEG: " << image_width << "x" << image_height << "x" + << image_channels; const int wanted_channels = 3; if (image_channels < wanted_channels) { - return tensorflow::errors::FailedPrecondition("Image needs to have at least ", - wanted_channels, " but only has ", - image_channels); + return tensorflow::errors::FailedPrecondition( + "Image needs to have at least ", wanted_channels, " but only has ", + image_channels); } - // In these loops, we convert the eight-bit data in the image into float, resize - // it using bilinear filtering, and scale it numerically to the float range that - // the model expects (given by input_mean and input_std). + // In these loops, we convert the eight-bit data in the image into float, + // resize it using bilinear filtering, and scale it numerically to the float + // range that the model expects (given by input_mean and input_std). tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, tensorflow::TensorShape( - {1, wanted_height, wanted_width, wanted_channels})); + tensorflow::DT_FLOAT, + tensorflow::TensorShape( + {1, wanted_height, wanted_width, wanted_channels})); auto image_tensor_mapped = image_tensor.tensor(); tensorflow::uint8* in = image_data.data(); - float *out = image_tensor_mapped.data(); + float* out = image_tensor_mapped.data(); const size_t image_rowlen = image_width * image_channels; const float width_scale = static_cast(image_width) / wanted_width; const float height_scale = static_cast(image_height) / wanted_height; @@ -166,35 +167,37 @@ Status ReadTensorFromImageFile(string file_name, const int wanted_height, const float in_y = y * height_scale; const int top_y_index = static_cast(floorf(in_y)); const int bottom_y_index = - std::min(static_cast(ceilf(in_y)), (image_height - 1)); + std::min(static_cast(ceilf(in_y)), (image_height - 1)); const float y_lerp = in_y - top_y_index; tensorflow::uint8* in_top_row = in + (top_y_index * image_rowlen); tensorflow::uint8* in_bottom_row = in + (bottom_y_index * image_rowlen); - float *out_row = out + (y * wanted_width * wanted_channels); + float* out_row = out + (y * wanted_width * wanted_channels); for (int x = 0; x < wanted_width; ++x) { const float in_x = x * width_scale; const int left_x_index = static_cast(floorf(in_x)); const int right_x_index = - std::min(static_cast(ceilf(in_x)), (image_width - 1)); + std::min(static_cast(ceilf(in_x)), (image_width - 1)); tensorflow::uint8* in_top_left_pixel = - in_top_row + (left_x_index * wanted_channels); + in_top_row + (left_x_index * wanted_channels); tensorflow::uint8* in_top_right_pixel = - in_top_row + (right_x_index * wanted_channels); + in_top_row + (right_x_index * wanted_channels); tensorflow::uint8* in_bottom_left_pixel = - in_bottom_row + (left_x_index * wanted_channels); + in_bottom_row + (left_x_index * wanted_channels); tensorflow::uint8* in_bottom_right_pixel = - in_bottom_row + (right_x_index * wanted_channels); + in_bottom_row + (right_x_index * wanted_channels); const float x_lerp = in_x - left_x_index; - float *out_pixel = out_row + (x * wanted_channels); + float* out_pixel = out_row + (x * wanted_channels); for (int c = 0; c < wanted_channels; ++c) { - const float top_left((in_top_left_pixel[c] - input_mean) / input_std); - const float top_right((in_top_right_pixel[c] - input_mean) / input_std); - const float bottom_left((in_bottom_left_pixel[c] - input_mean) / input_std); - const float bottom_right((in_bottom_right_pixel[c] - input_mean) / input_std); - const float top = top_left + (top_right - top_left) * x_lerp; - const float bottom = - bottom_left + (bottom_right - bottom_left) * x_lerp; - out_pixel[c] = top + (bottom - top) * y_lerp; + const float top_left((in_top_left_pixel[c] - input_mean) / input_std); + const float top_right((in_top_right_pixel[c] - input_mean) / input_std); + const float bottom_left((in_bottom_left_pixel[c] - input_mean) / + input_std); + const float bottom_right((in_bottom_right_pixel[c] - input_mean) / + input_std); + const float top = top_left + (top_right - top_left) * x_lerp; + const float bottom = + bottom_left + (bottom_right - bottom_left) * x_lerp; + out_pixel[c] = top + (bottom - top) * y_lerp; } } } @@ -233,10 +236,10 @@ Status GetTopLabels(const std::vector& outputs, int how_many_labels, scores.push_back(std::pair({i, unsorted_scores_flat(i)})); } std::sort(scores.begin(), scores.end(), - [](const std::pair &left, - const std::pair &right) { - return left.second > right.second; - }); + [](const std::pair& left, + const std::pair& right) { + return left.second > right.second; + }); scores.resize(how_many_labels); Tensor sorted_indices(tensorflow::DT_INT32, {scores.size()}); Tensor sorted_scores(tensorflow::DT_FLOAT, {scores.size()}); diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/py2tf/BUILD index 3e846aefeb30e29de8b00f76c0b8d7c6053e8099..d91220f6ddb859ff52d4e5853948cb667981009b 100644 --- a/tensorflow/contrib/py2tf/BUILD +++ b/tensorflow/contrib/py2tf/BUILD @@ -18,69 +18,14 @@ py_library( name = "py2tf", srcs = [ "__init__.py", - "api.py", - "config.py", - "conversion.py", - "naming.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/converters", + "//tensorflow/contrib/py2tf/impl", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/contrib/py2tf/utils", "@gast_archive//:gast", "@six_archive//:six", ], ) - -# Separate target that allows access to internal symbols for testing. -py_library( - name = "py2tf_internal", - srcs = [ - "api.py", - "config.py", - "conversion.py", - "naming.py", - ], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - "//tensorflow/contrib/py2tf/converters", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "@gast_archive//:gast", - "@six_archive//:six", - ], -) - -py_test( - name = "api_test", - srcs = ["api_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":py2tf_internal", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "conversion_test", - srcs = ["conversion_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":py2tf_internal", - "//tensorflow/python:client_testlib", - "@gast_archive//:gast", - ], -) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":py2tf_internal", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/py2tf/__init__.py index d187da99e065cb2d31ae4e45a9570378f9d1bf27..379fa7fd5c2a22b5b16a21cca8c2ea8afdcaeefa 100644 --- a/tensorflow/contrib/py2tf/__init__.py +++ b/tensorflow/contrib/py2tf/__init__.py @@ -21,11 +21,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.api import to_code -from tensorflow.contrib.py2tf.api import to_graph +from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.py2tf.impl.api import convert +from tensorflow.contrib.py2tf.impl.api import graph_ready +from tensorflow.contrib.py2tf.impl.api import to_code +from tensorflow.contrib.py2tf.impl.api import to_graph +from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ['to_graph', 'to_code'] +_allowed_symbols = [ + 'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError' +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/py2tf/converters/BUILD index 4f90f94e0960b4afaec1b27d25a5abd53322f229..cb9dec74ffc30c9d8c87a223083abb2650dfd2fa 100644 --- a/tensorflow/contrib/py2tf/converters/BUILD +++ b/tensorflow/contrib/py2tf/converters/BUILD @@ -17,6 +17,7 @@ filegroup( py_library( name = "converters", srcs = [ + "asserts.py", "break_canonicalization.py", "builtin_functions.py", "call_trees.py", @@ -25,7 +26,6 @@ py_library( "decorators.py", "for_canonicalization.py", "logical_expressions.py", - "print_functions.py", "side_effect_guards.py", ], srcs_version = "PY2AND3", @@ -45,10 +45,22 @@ py_library( deps = [ ":converters", "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/contrib/py2tf/utils", "@gast_archive//:gast", ], ) +py_test( + name = "asserts_test", + srcs = ["asserts_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "break_canonicalization_test", srcs = ["break_canonicalization_test.py"], @@ -71,6 +83,17 @@ py_test( ], ) +py_test( + name = "decorators_test", + srcs = ["decorators_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "continue_canonicalization_test", srcs = ["continue_canonicalization_test.py"], @@ -107,6 +130,7 @@ py_test( py_test( name = "for_canonicalization_test", srcs = ["for_canonicalization_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -125,18 +149,6 @@ py_test( ], ) -py_test( - name = "print_functions_test", - srcs = ["print_functions_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":test_lib", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/python:client_testlib", - "@gast_archive//:gast", - ], -) - py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], diff --git a/tensorflow/contrib/py2tf/converters/print_functions.py b/tensorflow/contrib/py2tf/converters/asserts.py similarity index 54% rename from tensorflow/contrib/py2tf/converters/print_functions.py rename to tensorflow/contrib/py2tf/converters/asserts.py index 5da738c4954fb628212562b73641e1fc27032168..2d6ee1d09829b538815dbb9794868c13f51578fc 100644 --- a/tensorflow/contrib/py2tf/converters/print_functions.py +++ b/tensorflow/contrib/py2tf/converters/asserts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Compatibility support. Converts Print nodes to function calls.""" +"""Converts Assert statements to their corresponding TF calls.""" from __future__ import absolute_import from __future__ import division @@ -20,32 +20,34 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer -class PrintFunctionTransformer(gast.NodeTransformer): +class AssertsTransformer(transformer.Base): """Transforms Print nodes to Call so they can be handled as functions.""" # pylint:disable=invalid-name - def visit_Print(self, node): + def visit_Assert(self, node): self.generic_visit(node) - for n in node.values: - n.ctx = gast.Param() - call_node = gast.Call( - func=gast.Name('print', gast.Load(), None), - args=node.values, - keywords=[]) - anno.setanno(call_node.func, 'live_val', print) - anno.setanno(call_node.func, 'fqn', 'print') - anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope')) - node = gast.Expr(call_node) - return node + + # Note: The lone tf.Assert call will be wrapped with control_dependencies + # by side_effect_guards. + template = """ + tf.Assert(test, [tf.constant(msg)]) + """ + + if node.msg is None: + return templates.replace( + template, test=node.test, msg=gast.Str('Assertion error')) + elif isinstance(node.msg, gast.Str): + return templates.replace(template, test=node.test, msg=node.msg) + else: + raise NotImplementedError('Can only convert string messages for now.') # pylint:enable=invalid-name -def transform(node): - transformer = PrintFunctionTransformer() - node = transformer.visit(node) - return node +def transform(node, context): + return AssertsTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/print_functions_test.py b/tensorflow/contrib/py2tf/converters/asserts_test.py similarity index 73% rename from tensorflow/contrib/py2tf/converters/print_functions_test.py rename to tensorflow/contrib/py2tf/converters/asserts_test.py index 475196ce102955b350acf9bf94255997f875f62c..6611f2777a93a7e819c8becfa06a09b27f4e6aaf 100644 --- a/tensorflow/contrib/py2tf/converters/print_functions_test.py +++ b/tensorflow/contrib/py2tf/converters/asserts_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for print_functions module.""" +"""Tests for asserts module.""" from __future__ import absolute_import from __future__ import division @@ -20,24 +20,21 @@ from __future__ import print_function import gast +from tensorflow.contrib.py2tf.converters import asserts from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import print_functions -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.platform import test -class PrintFunctionsTest(converter_test_base.TestCase): +class AssertsTest(converter_test_base.TestCase): def test_transform(self): def test_fn(a): - print(a) + assert a > 0 - node = self.parse_and_analyze(test_fn, {'print': print}) - node = print_functions.transform(node) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = asserts.transform(node, self.ctx) - result.test_fn('a') self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call)) diff --git a/tensorflow/contrib/py2tf/converters/break_canonicalization.py b/tensorflow/contrib/py2tf/converters/break_canonicalization.py index 2ae65e3007466409433e9b4ea0081898907e19ac..bfb709c5e32c6f19dc0fd109df61ece925d701a3 100644 --- a/tensorflow/contrib/py2tf/converters/break_canonicalization.py +++ b/tensorflow/contrib/py2tf/converters/break_canonicalization.py @@ -22,13 +22,15 @@ import gast from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class BreakCanonicalizationTransformer(gast.NodeTransformer): +class BreakCanonicalizationTransformer(transformer.Base): """Canonicalizes continue statements into additional conditionals.""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(BreakCanonicalizationTransformer, self).__init__(context) # This is a stack structure, to correctly process nested loops. self.break_uses = [] @@ -67,9 +69,10 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer): def visit_While(self, node): self.generic_visit(node.test) - scope = anno.getanno(node, 'body_scope') + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.namer.new_symbol('break_requested', scope.referenced) + break_var = self.context.namer.new_symbol('break_requested', + scope.referenced) self.break_uses.append([False, break_var]) node.body = self._manual_visit_list(node.body) if self.break_uses[-1][0]: @@ -89,9 +92,10 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer): def visit_For(self, node): self.generic_visit(node.target) self.generic_visit(node.iter) - scope = anno.getanno(node, 'body_scope') + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.namer.new_symbol('break_requested', scope.referenced) + break_var = self.context.namer.new_symbol('break_requested', + scope.referenced) self.break_uses.append([False, break_var]) node.body = self._manual_visit_list(node.body) if self.break_uses[-1][0]: @@ -112,7 +116,5 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer): return self._create_break_trigger() -def transform(node, namer): - transformer = BreakCanonicalizationTransformer(namer) - node = transformer.visit(node) - return node +def transform(node, context): + return BreakCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py index b5ba2ad923dfeb73b38169494f6c7ea16ee815f1..54c4d99361f00ba2b5b79323f5feddcbbdfc99e8 100644 --- a/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py @@ -44,8 +44,8 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = break_canonicalization.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = break_canonicalization.transform(node, self.ctx) result = compiler.ast_to_object(node) self.assertEqual(test_fn(0), result.test_fn(0)) @@ -76,8 +76,8 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = break_canonicalization.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = break_canonicalization.transform(node, self.ctx) result = compiler.ast_to_object(node) # The break is incompletely canonicalized. Everything is in place, but @@ -104,8 +104,8 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = break_canonicalization.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = break_canonicalization.transform(node, self.ctx) result = compiler.ast_to_object(node) self.assertEqual(test_fn(0), result.test_fn(0)) diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py index 7f6b64a34c1b95f0dd6b92dbc587da672e6c9c28..3e56634106c2c9c1e4c334d4e61cedee395853a9 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py @@ -21,12 +21,14 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer -class BuiltinFunctionTransformer(gast.NodeTransformer): +class BuiltinFunctionTransformer(transformer.Base): """Transforms Print nodes to Call so they can be handled as functions.""" - # TODO(mdan): Bring print_functions in here. + def __init__(self, context): + super(BuiltinFunctionTransformer, self).__init__(context) def _convert_len(self, node): template = """ @@ -44,10 +46,15 @@ class BuiltinFunctionTransformer(gast.NodeTransformer): return self._convert_len(node) return node + def visit_Print(self, node): + self.generic_visit(node) + template = """ + fname(args) + """ + return templates.replace(template, fname='print', args=node.values) + # pylint:enable=invalid-name -def transform(node): - transformer = BuiltinFunctionTransformer() - node = transformer.visit(node) - return node +def transform(node, context): + return BuiltinFunctionTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py index b5358da6bc0be06ec1f59d0ef58d926289b5b78f..be76066242856c85784221166a70299187b11b14 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.contrib.py2tf.converters import builtin_functions from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.pyct import compiler @@ -34,7 +36,7 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): return len(a) node = self.parse_and_analyze(test_fn, {'len': len}) - node = builtin_functions.transform(node) + node = builtin_functions.transform(node, self.ctx) result = compiler.ast_to_object(node) setattr(result, 'tf', array_ops) @@ -43,6 +45,18 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): sess.run( result.test_fn(constant_op.constant([0, 0, 0])))) + def test_print(self): + + def test_fn(a): + print(a) + + node = self.parse_and_analyze(test_fn, {'print': print}) + node = builtin_functions.transform(node, self.ctx) + result = compiler.ast_to_object(node) + + result.test_fn('a') + self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py index 0aae030450ae2b981328f604bfddec2f25e13ec4..834baf258d3b5a0eee02e79ccdf0ebfc61b2c9df 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/py2tf/converters/call_trees.py @@ -29,46 +29,47 @@ import gast from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.python.util import tf_inspect class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" def compiled_function_name(self, - original_name, - live_object=None, + original_fqn, + live_entity=None, owner_type=None): """Generate the name corresponding to the compiled version of a function. Args: - original_name: String - live_object: Callable, the actual target function, if known. + original_fqn: string or tuple(string) + live_entity: Callable, the actual target function, if known. owner_type: Optional object. If present, it indicates that the function is a member of the given type. Returns: - String. + string, bool """ raise NotImplementedError() - def compiled_class_name(self, original_name, live_object=None): + def compiled_class_name(self, original_fqn, live_entity=None): """Generate the name corresponding to the compiled version of a class. Args: - original_name: String - live_object: The actual target class, if known. + original_fqn: string or tuple(string) + live_entity: The actual target class, if known. Returns: - String. + string """ raise NotImplementedError() -class CallTreeTransformer(gast.NodeTransformer): +class CallTreeTransformer(transformer.Base): """Transforms the call tree by renaming transformed symbols.""" - def __init__(self, namer, namespace, uncompiled_modules, - nocompile_decorators): - self.namer = namer - self.namespace = namespace + def __init__(self, context, uncompiled_modules, nocompile_decorators): + super(CallTreeTransformer, self).__init__(context) self.uncompiled_modules = uncompiled_modules self.nocompile_decorators = nocompile_decorators @@ -78,7 +79,7 @@ class CallTreeTransformer(gast.NodeTransformer): if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): - return self.namespace.get(node.id) + return self.context.namespace.get(node.id) if isinstance(node, gast.Attribute): parent = self._resolve_name(node.value) if parent is not None: @@ -91,8 +92,12 @@ class CallTreeTransformer(gast.NodeTransformer): if anno.hasanno(node, 'live_val'): return anno.getanno(node, 'live_val') if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'): - member = getattr(anno.getanno(node, 'type'), node.attr) - return member + owner_type = anno.getanno(node, 'type') + if hasattr(owner_type, node.attr): + return getattr(owner_type, node.attr) + else: + raise ValueError('Type "%s" has not attribute "%s". Is it dynamic?' % + (owner_type, node.attr)) return None def _should_compile(self, node, fqn): @@ -106,14 +111,14 @@ class CallTreeTransformer(gast.NodeTransformer): # The decorators themselves are not to be converted. # If present, the decorators should appear as static functions. - target_obj = self._try_resolve_target(node.func) - if target_obj is not None: + target_entity = self._try_resolve_target(node.func) + if target_entity is not None: # This attribute is set by the decorator itself. # TODO(mdan): This may not play nicely with other wrapping decorators. - if hasattr(target_obj, '__pyct_is_compile_decorator'): + if hasattr(target_entity, '__pyct_is_compile_decorator'): return False - if target_obj in self.nocompile_decorators: + if target_entity in self.nocompile_decorators: return False # Inspect the target function decorators. If any include a @convert @@ -122,7 +127,8 @@ class CallTreeTransformer(gast.NodeTransformer): # To parse and re-analize each function for every call site could be quite # wasteful. Maybe we could cache the parsed AST? try: - target_node = parser.parse_object(target_obj).body[0] + target_node, _ = parser.parse_entity(target_entity) + target_node = target_node.body[0] except TypeError: # Functions whose source we cannot access are compilable (e.g. wrapped # to py_func). @@ -136,53 +142,62 @@ class CallTreeTransformer(gast.NodeTransformer): return True + def _determine_function_owner(self, m): + # TODO(mdan): The parent type should be known at analysis. Use that instead. + if hasattr(m, 'im_class'): # Python 2 + return m.im_class + if hasattr(m, '__qualname__'): # Python 3 + # Object attributes: should be bound to "self". + if hasattr(m, '__self__'): + return type(m.__self__) + + # Class attributes: should have the owner name in their namespace. + qn = m.__qualname__.split('.') + if len(qn) < 2: + return None + owner_name, func_name = qn[-2:] + if func_name != m.__name__: + raise ValueError('Inconsistent names detected ' + '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % + (func_name, m.__name__, m)) + if owner_name == '': + return None + if owner_name not in self.context.namespace: + raise ValueError( + 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' % + (owner_name, m, self.context.namespace)) + return self.context.namespace[owner_name] + return None + def _rename_compilable_function(self, node): assert anno.hasanno(node.func, 'live_val') assert anno.hasanno(node.func, 'fqn') - target_obj = anno.getanno(node.func, 'live_val') + target_entity = anno.getanno(node.func, 'live_val') target_fqn = anno.getanno(node.func, 'fqn') if not self._should_compile(node, target_fqn): return node if anno.hasanno(node, 'is_constructor'): - new_name = self.namer.compiled_class_name( - '__'.join(target_fqn), live_object=target_obj) + new_name = self.context.namer.compiled_class_name( + target_fqn, live_entity=target_entity) + do_rename = True else: - new_name = self.namer.compiled_function_name( - '__'.join(target_fqn), live_object=target_obj) - node.func = gast.Name(new_name, gast.Load(), None) - return node - - def _rename_member_function_of_known_type(self, node): - assert isinstance(node.func, gast.Attribute) - - type_fqn = anno.getanno(node.func, 'type_fqn') - assert anno.hasanno(node.func, 'type') - target_type = anno.getanno(node.func, 'type') - - if not self._should_compile(node, type_fqn): - return node - - # TODO(mdan): We should not assume that the namer only needs the - # member function name. - method_name = node.func.attr - method_object = getattr(target_type, method_name) - new_name = self.namer.compiled_function_name( - method_name, live_object=method_object, owner_type=target_type) - if new_name != node.func.attr: - # If a member function call is renamed, then the new function is no - # longer bound to the target object. We then refactor the call from: - # foo.bar(...) - # to: - # renamed_foo(bar, ...) - # TODO(mdan): This risks causing duplication, if target_type is renamed. - node.args = [node.func.value] + node.args - node.func = gast.Name(new_name, gast.Load(), None) + owner_type = self._determine_function_owner(target_entity) + new_name, do_rename = self.context.namer.compiled_function_name( + target_fqn, live_entity=target_entity, owner_type=owner_type) + + if do_rename: + if target_entity is not None: + if tf_inspect.ismethod(target_entity): + # The renaming process will transform it into a regular function. + # TODO(mdan): Is this complete? How does it work with nested members? + node.args = [node.func.value] + node.args + node.func = templates.replace('func_name', func_name=new_name)[0] return node def _wrap_to_py_func_no_return(self, node): - args_scope = anno.getanno(node, 'args_scope') + args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) # TODO(mdan): Properly handle varargs, kwargs, etc. template = """ def wrapper(args): @@ -193,23 +208,23 @@ class CallTreeTransformer(gast.NodeTransformer): wrapper_def, call_expr = templates.replace( template, call=node.func, - wrapper=self.namer.compiled_function_name(node.func.id), - args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)) - anno.setanno(call_expr.value, 'args_scope', args_scope) + wrapper=self.context.namer.compiled_function_name(node.func.id)[0], + args=tuple(args_scope.used)) + anno.setanno(call_expr.value, NodeAnno.ARGS_SCOPE, args_scope) # TODO(mdan): Rename this annotation to 'graph_ready' - anno.setanno(wrapper_def, 'skip_processing', True) + anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True) return (wrapper_def, call_expr) - def _function_is_compilable(self, target_obj): + def _function_is_compilable(self, target_entity): # TODO(mdan): This is just a placeholder. Implement. - return not isinstance(target_obj, types.BuiltinFunctionType) + return not isinstance(target_entity, types.BuiltinFunctionType) def visit_Expr(self, node): if isinstance(node.value, gast.Call): if anno.hasanno(node.value.func, 'live_val'): - target_obj = anno.getanno(node.value.func, 'live_val') - if not self._function_is_compilable(target_obj): + target_entity = anno.getanno(node.value.func, 'live_val') + if not self._function_is_compilable(target_entity): if anno.hasanno(node.value.func, 'fqn'): target_fqn = anno.getanno(node.value.func, 'fqn') if not self._should_compile(node.value, target_fqn): @@ -227,8 +242,8 @@ class CallTreeTransformer(gast.NodeTransformer): # If the function is wrapped by one of the marker decorators, # consider it graph ready. if anno.hasanno(node.func, 'live_val'): - target_obj = anno.getanno(node.func, 'live_val') - if target_obj in self.nocompile_decorators: + target_entity = anno.getanno(node.func, 'live_val') + if target_entity in self.nocompile_decorators: if len(node.args) < 1: raise ValueError( 'Found call to decorator function "%s", but it had no arguments. ' @@ -237,28 +252,28 @@ class CallTreeTransformer(gast.NodeTransformer): self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): - target_obj = anno.getanno(node.func, 'live_val') - if self._function_is_compilable(target_obj): + target_entity = anno.getanno(node.func, 'live_val') + if self._function_is_compilable(target_entity): node = self._rename_compilable_function(node) else: raise NotImplementedError('py_func with return values') - elif anno.hasanno(node.func, 'type_fqn'): - node = self._rename_member_function_of_known_type(node) else: - raise NotImplementedError( - 'Member function call (of unknown type): %s.' % node.func.id) + if self.context.recursive: + raise NotImplementedError('Could not resolve target function.') + else: + # TODO(mdan): Double check. Is this reachable code? + pass return node # pylint:enable=invalid-name -def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators): +def transform(node, context, uncompiled_modules, nocompile_decorators): """Transform function call to the compiled counterparts. Args: node: AST to transform. - namer: FunctionNamer-like. - namespace: Dict mapping symbol names to their corresponding live objects. + context: An EntityContext object. uncompiled_modules: set of string tuples, each tuple represents the fully qualified name of a package containing functions that will not be compiled. @@ -269,7 +284,6 @@ def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators): node: The transformed AST new_names: set(string), containing any newly-generated names """ - transformer = CallTreeTransformer(namer, namespace, uncompiled_modules, - nocompile_decorators) - node = transformer.visit(node) + t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators) + node = t.visit(node) return node diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/py2tf/converters/call_trees_test.py index 8cb8d7be0f122ed124b0fda69c745a349543a16d..e63c10de0fed72333a6d571f9b9a4f1cb50b5f1d 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees_test.py +++ b/tensorflow/contrib/py2tf/converters/call_trees_test.py @@ -28,8 +28,13 @@ from tensorflow.python.platform import test class TestNamer(call_trees.FunctionNamer): - def compiled_function_name(self, original_name, live_object=None): - return 'renamed_%s' % original_name + def compiled_function_name(self, + original_fqn, + live_entity=None, + owner_type=None): + if owner_type is not None: + return None, False + return ('renamed_%s' % '_'.join(original_fqn)), True class CallTreesTest(converter_test_base.TestCase): @@ -45,14 +50,35 @@ class CallTreesTest(converter_test_base.TestCase): def test_fn_2(a): return test_fn_1(a) + 1 - node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, TestNamer(), {}, (), ()) + node = self.parse_and_analyze( + test_fn_2, {'test_fn_1': test_fn_1}, namer=TestNamer()) + node = call_trees.transform(node, self.ctx, (), ()) result = compiler.ast_to_object(node) # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually. setattr(result, 'renamed_test_fn_1', renamed_test_fn_1) self.assertEquals(3, result.test_fn_2(1)) + def test_simple_methods(self): + + class TestClass(object): + + def test_fn_1(self, a): + return a + 1 + + def test_fn_2(self, a): + return self.test_fn_1(a) + 1 + + node = self.parse_and_analyze( + TestClass.test_fn_2, {'TestClass': TestClass}, + namer=TestNamer(), + arg_types={'self': (TestClass.__name__, TestClass)}) + node = call_trees.transform(node, self.ctx, (), ()) + result = compiler.ast_to_object(node) + + tc = TestClass() + self.assertEquals(3, result.test_fn_2(tc, 1)) + def test_uncompiled_modules(self): def test_fn(a): @@ -60,11 +86,13 @@ class CallTreesTest(converter_test_base.TestCase): a = math_ops.add(a, constant_op.constant(1)) return a - node = self.parse_and_analyze(test_fn, { - 'math_ops': math_ops, - 'constant_op': constant_op - }) - node = call_trees.transform(node, TestNamer(), {}, + node = self.parse_and_analyze( + test_fn, { + 'math_ops': math_ops, + 'constant_op': constant_op + }, + namer=TestNamer()) + node = call_trees.transform(node, self.ctx, set(((math_ops.__name__,), (constant_op.__name__,))), ()) result = compiler.ast_to_object(node) diff --git a/tensorflow/contrib/py2tf/converters/continue_canonicalization.py b/tensorflow/contrib/py2tf/converters/continue_canonicalization.py index 486f0f6509d67d9d981e43ea6e5c77d14e6b23fc..4069a678b118b56b59d2e5491bb80cf52efd8143 100644 --- a/tensorflow/contrib/py2tf/converters/continue_canonicalization.py +++ b/tensorflow/contrib/py2tf/converters/continue_canonicalization.py @@ -18,17 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class ContinueCanonicalizationTransformer(gast.NodeTransformer): +class ContinueCanonicalizationTransformer(transformer.Base): """Canonicalizes continue statements into additional conditionals.""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(ContinueCanonicalizationTransformer, self).__init__(context) # This is a stack structure, to correctly process nested loops. self.continuation_uses = [] @@ -76,7 +76,7 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): return reorganized_nodes def _process_loop_block(self, block, scope): - cont_var = self.namer.new_symbol('cont_requested', scope.referenced) + cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced) self.continuation_uses.append([False, cont_var]) block = self._visit_and_reindent_if_necessary(block) if self.continuation_uses[-1][0]: @@ -87,7 +87,8 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): def visit_While(self, node): self.generic_visit(node.test) node.body = self._process_loop_block(node.body, - anno.getanno(node, 'body_scope')) + anno.getanno(node, + NodeAnno.BODY_SCOPE)) for n in node.orelse: self.generic_visit(n) return node @@ -96,7 +97,8 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): self.generic_visit(node.target) self.generic_visit(node.iter) node.body = self._process_loop_block(node.body, - anno.getanno(node, 'body_scope')) + anno.getanno(node, + NodeAnno.BODY_SCOPE)) for n in node.orelse: self.generic_visit(n) return node @@ -122,6 +124,4 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): def transform(node, namer): - transformer = ContinueCanonicalizationTransformer(namer) - node = transformer.visit(node) - return node + return ContinueCanonicalizationTransformer(namer).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py index c1fe903a2dd332626c8e64826652723c30ac412a..4b188195595e4ea4be4359eb87830d587c1de1de 100644 --- a/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py @@ -44,8 +44,8 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = continue_canonicalization.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = continue_canonicalization.transform(node, self.ctx) result = compiler.ast_to_object(node) self.assertEqual(test_fn(0), result.test_fn(0)) @@ -65,8 +65,8 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = continue_canonicalization.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = continue_canonicalization.transform(node, self.ctx) result = compiler.ast_to_object(node) self.assertEqual(test_fn([]), result.test_fn([])) @@ -91,8 +91,8 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = continue_canonicalization.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = continue_canonicalization.transform(node, self.ctx) result = compiler.ast_to_object(node) self.assertEqual(test_fn(0), result.test_fn(0)) diff --git a/tensorflow/contrib/py2tf/converters/control_flow.py b/tensorflow/contrib/py2tf/converters/control_flow.py index a40c7b28f7bc3b8483b0b18cf11dbf99456df645..a256c074a807b6d8ff3a2573f6d57e82624c0229 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow.py +++ b/tensorflow/contrib/py2tf/converters/control_flow.py @@ -22,6 +22,8 @@ import gast from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -41,21 +43,29 @@ class SymbolNamer(object): class SymbolRenamer(gast.NodeTransformer): + """Transformer that can rename symbols to a simple names.""" def __init__(self, name_map): self.name_map = name_map - def visit_Name(self, node): - if node.id in self.name_map: - node.id = self.name_map[node.id] + def _process(self, node): + qn = anno.getanno(node, anno.Basic.QN) + if qn in self.name_map: + return gast.Name(self.name_map[qn], node.ctx, None) return node + def visit_Name(self, node): + return self._process(node) + + def visit_Attribute(self, node): + return self._process(node) + -class ControlFlowTransformer(gast.NodeTransformer): +class ControlFlowTransformer(transformer.Base): """Transforms control flow structures like loops an conditionals.""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(ControlFlowTransformer, self).__init__(context) # pylint:disable=invalid-name @@ -65,8 +75,8 @@ class ControlFlowTransformer(gast.NodeTransformer): def visit_If(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, 'body_scope') - orelse_scope = anno.getanno(node, 'orelse_scope') + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) if body_scope.created - orelse_scope.created: raise ValueError( @@ -86,7 +96,8 @@ class ControlFlowTransformer(gast.NodeTransformer): (body_scope.created | orelse_scope.created)) aliased_orig_names = tuple(need_alias) aliased_new_names = tuple( - self.namer.new_symbol(s, all_referenced) for s in aliased_orig_names) + self.context.namer.new_symbol(s.ssf(), all_referenced) + for s in aliased_orig_names) alias_map = dict(zip(aliased_orig_names, aliased_new_names)) node_body = node.body node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body] @@ -94,72 +105,112 @@ class ControlFlowTransformer(gast.NodeTransformer): node_orelse = [SymbolRenamer(alias_map).visit(n) for n in node_orelse] if len(all_modified) == 1: - results = gast.Name(all_modified[0], None, None) + results = all_modified[0] else: - results = gast.Tuple( - tuple(gast.Name(s, None, None) for s in all_modified), None) - - template = """ - def body_name(): - aliased_new_names, = aliased_orig_names, - body - return (all_results,) - def orelse_name(): - aliased_new_names, = aliased_orig_names, - orelse - return (all_results,) - results = tf.cond(test, body_name, orelse_name) - """ - body_name = self.namer.new_symbol('if_true', all_referenced) - return templates.replace( - template, - test=node.test, - body_name=body_name, - body=node_body, - orelse_name=self.namer.new_symbol('if_false', all_referenced), - orelse=node_orelse, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), - all_results=tuple(alias_map[s] if s in aliased_orig_names else s - for s in all_modified), - results=results) + results = gast.Tuple([s.ast() for s in all_modified], None) + + if aliased_orig_names: + template = """ + def body_name(): + aliased_new_names, = aliased_orig_names, + body + return (all_results,) + def orelse_name(): + aliased_new_names, = aliased_orig_names, + orelse + return (all_results,) + results = tf.cond(test, body_name, orelse_name) + """ + body_name = self.context.namer.new_symbol('if_true', all_referenced) + return templates.replace( + template, + test=node.test, + body_name=body_name, + body=node_body, + orelse_name=self.context.namer.new_symbol('if_false', all_referenced), + orelse=node_orelse, + aliased_orig_names=tuple(aliased_orig_names), + aliased_new_names=tuple(aliased_new_names), + all_results=tuple(alias_map[s] if s in aliased_orig_names else s + for s in all_modified), + results=results) + else: + template = """ + def body_name(): + body + return (all_results,) + def orelse_name(): + orelse + return (all_results,) + results = tf.cond(test, body_name, orelse_name) + """ + body_name = self.context.namer.new_symbol('if_true', all_referenced) + return templates.replace( + template, + test=node.test, + body_name=body_name, + body=node_body, + orelse_name=self.context.namer.new_symbol('if_false', all_referenced), + orelse=node_orelse, + all_results=tuple(s for s in all_modified), + results=results) def visit_While(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, 'body_scope') - body_closure = tuple(body_scope.modified - body_scope.created) - - if len(body_closure) == 1: - state = body_closure[0] + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_closure = body_scope.modified - body_scope.created + all_referenced = body_scope.referenced + + state = list(body_closure) + state_ssf = [ + self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + ] + ssf_map = { + name: ssf + for name, ssf in zip(state, state_ssf) + if str(name) != ssf + } + + if len(state) == 1: + state = state[0] + state_ssf = state_ssf[0] state_ast_tuple = state else: - state = tuple(body_closure) - state_ast_tuple = gast.Tuple( - tuple(gast.Name(n, None, None) for n in state), None) + state_ast_tuple = gast.Tuple([n.ast() for n in state], None) + + node_body = node.body + node_body = [SymbolRenamer(ssf_map).visit(n) for n in node_body] + + test = node.test + test = SymbolRenamer(ssf_map).visit(test) + template = """ - def test_name(state): + def test_name(state_ssf): return test - def body_name(state): + def body_name(state_ssf): body - return state, + return state_ssf, state_ast_tuple = tf.while_loop(test_name, body_name, [state]) """ node = templates.replace( template, state=state, + state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - test_name=self.namer.new_symbol('loop_test', body_scope.referenced), - test=node.test, - body_name=self.namer.new_symbol('loop_body', body_scope.referenced), - body=node.body) + test_name=self.context.namer.new_symbol('loop_test', + body_scope.referenced), + test=test, + body_name=self.context.namer.new_symbol('loop_body', + body_scope.referenced), + body=node_body) return node # pylint:enable=invalid-name -def transform(node, namer): - transformer = ControlFlowTransformer(namer) - node = transformer.visit(node) +def transform(node, context): + t = ControlFlowTransformer(context) + node = t.visit(node) return node diff --git a/tensorflow/contrib/py2tf/converters/control_flow_test.py b/tensorflow/contrib/py2tf/converters/control_flow_test.py index 054e33750dbae86559a9575dfecde64132b9a2cd..f192bf1b465f6d88578107268d7974cd97a01623 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow_test.py +++ b/tensorflow/contrib/py2tf/converters/control_flow_test.py @@ -49,8 +49,8 @@ class ControlFlowTest(converter_test_base.TestCase): i += 1 return s, i, n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = control_flow.transform(node, self.ctx) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) @@ -65,8 +65,8 @@ class ControlFlowTest(converter_test_base.TestCase): n -= 1 return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = control_flow.transform(node, self.ctx) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) @@ -84,8 +84,8 @@ class ControlFlowTest(converter_test_base.TestCase): b = 2 * n return a, b - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = control_flow.transform(node, self.ctx) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) @@ -102,8 +102,8 @@ class ControlFlowTest(converter_test_base.TestCase): n = -n return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = control_flow.transform(node, self.ctx) result = compiler.ast_to_object(node) setattr(result, 'tf', control_flow_ops) diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/py2tf/converters/converter_test_base.py index ed006bad6d833b3682f819e87aa8b9c279372e51..bcb96c81ae762f90159797c929a261a5b7d4fa83 100644 --- a/tensorflow/contrib/py2tf/converters/converter_test_base.py +++ b/tensorflow/contrib/py2tf/converters/converter_test_base.py @@ -20,7 +20,8 @@ from __future__ import print_function from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.platform import test @@ -31,18 +32,24 @@ class TestCase(test.TestCase): def parse_and_analyze(self, test_fn, namespace, + namer=None, arg_types=None, - include_type_analysis=True): + include_type_analysis=True, + recursive=True): + node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( - namer=None, - source_code=None, + namer=namer, + source_code=source, source_file=None, namespace=namespace, arg_values=None, - arg_types=arg_types) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) + arg_types=arg_types, + recursive=recursive) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) if include_type_analysis: node = type_info.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) + self.ctx = ctx return node diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/py2tf/converters/decorators.py index a4313bfa510a81463a218cd21b41d9a7f43d1892..3f620c1cd2d9b75f82410754a7e812e13eabe3ae 100644 --- a/tensorflow/contrib/py2tf/converters/decorators.py +++ b/tensorflow/contrib/py2tf/converters/decorators.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Handles decorators.""" +"""Handles decorators. + +Note: this module only deals with functions whose decorators are still recorded +in the AST. This does not always happen. See the unit test for an example. +""" from __future__ import absolute_import from __future__ import division @@ -34,17 +38,19 @@ class DecoratorsTransformer(gast.NodeTransformer): def visit_FunctionDef(self, node): self.generic_visit(node) + kept_decorators = [] for dec in node.decorator_list: if isinstance(dec, gast.Call): - dec = dec.func - if not anno.hasanno(dec, 'live_val'): + dec_func = dec.func + else: + dec_func = dec + if not anno.hasanno(dec_func, 'live_val'): raise ValueError( - 'Could not resolve decorator: %s' % pretty_printer.fmt(dec)) - dec_value = anno.getanno(dec, 'live_val') - if dec_value in self.remove_decorators: - continue - raise ValueError('Dont know how to convert decorators for now.') - node.decorator_list = [] + 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) + dec_value = anno.getanno(dec_func, 'live_val') + if dec_value not in self.remove_decorators: + kept_decorators.append(dec) + node.decorator_list = kept_decorators return node # pylint:enable=invalid-name diff --git a/tensorflow/contrib/py2tf/converters/decorators_test.py b/tensorflow/contrib/py2tf/converters/decorators_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f50d593043aeb76d63beb3cb6c301122c9ed8948 --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/decorators_test.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for decorators module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import decorators +from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +class DecoratorsTest(converter_test_base.TestCase): + + def test_function_decorator(self): + + def function_decorator(): + + def decorator(f): + return lambda a: f(a) + 1 + + return decorator + + # The Python parser does capture decorators into the AST. + # However, the interpreter desugars them on load, and refering to the + # decorated function at runtime usually loses any trace of the decorator. + # Below is an example when that doesn't happen. + def static_wrapper(): + + @function_decorator() + def test_fn(a): # pylint:disable=unused-variable + return a + + node = self.parse_and_analyze(static_wrapper, + {'function_decorator': function_decorator}) + node = node.body[0].body[0] + + node = decorators.transform(node, remove_decorators=()) + result = compiler.ast_to_object( + node, + source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator))) + self.assertEqual(2, result.test_fn(1)) + + node = decorators.transform(node, remove_decorators=(function_decorator,)) + result = compiler.ast_to_object(node) + self.assertEqual(1, result.test_fn(1)) + + def test_simple_decorator(self): + + def simple_decorator(f): + return lambda a: f(a) + 1 + + # The Python parser does capture decorators into the AST. + # However, the interpreter desugars them upon load, and refering to the + # decorated function at runtime usually loses any trace of the decorator. + # Below is an example when that doesn't happen. + def static_wrapper(): + + @simple_decorator + def test_fn(a): # pylint:disable=unused-variable + return a + + node = self.parse_and_analyze(static_wrapper, + {'simple_decorator': simple_decorator}) + node = node.body[0].body[0] + + node = decorators.transform(node, remove_decorators=()) + result = compiler.ast_to_object( + node, + source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator))) + self.assertEqual(2, result.test_fn(1)) + + node = decorators.transform(node, remove_decorators=(simple_decorator,)) + result = compiler.ast_to_object(node) + self.assertEqual(1, result.test_fn(1)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/for_canonicalization.py b/tensorflow/contrib/py2tf/converters/for_canonicalization.py index c284689b904c6f372f30e83c259416a51babe4a6..935dade0ed30975dd29c8ffe5be875993936d241 100644 --- a/tensorflow/contrib/py2tf/converters/for_canonicalization.py +++ b/tensorflow/contrib/py2tf/converters/for_canonicalization.py @@ -22,24 +22,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class ForLoopCanonicalizationTransformer(gast.NodeTransformer): +class ForLoopCanonicalizationTransformer(transformer.Base): """Canonicalizes for loops (e.g. into while loops).""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(ForLoopCanonicalizationTransformer, self).__init__(context) def visit_For(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, 'body_scope') - - # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)` - # Or maybe we should replace range with tf.range? + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) if anno.hasanno(node, 'extra_cond'): template = """ @@ -56,8 +53,8 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer): loop_iter=node.iter, target=node.target, body=node.body, - i=self.namer.new_symbol('i', body_scope.referenced), - n=self.namer.new_symbol('n', body_scope.referenced), + i=self.context.namer.new_symbol('i', body_scope.referenced), + n=self.context.namer.new_symbol('n', body_scope.referenced), extra_cond=anno.getanno(node, 'extra_cond')) else: template = """ @@ -69,13 +66,14 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer): body # pylint:disable=pointless-statement i += 1 """ - return templates.replace( + repl = templates.replace( template, loop_iter=node.iter, target=node.target, body=node.body, - i=self.namer.new_symbol('i', body_scope.referenced), - n=self.namer.new_symbol('n', body_scope.referenced)) + i=self.context.namer.new_symbol('i', body_scope.referenced), + n=self.context.namer.new_symbol('n', body_scope.referenced)) + return repl def visit_Continue(self, node): assert False, 'continue statement should be desugared at this point' @@ -84,7 +82,5 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer): assert False, 'break statement should be desugared at this point' -def transform(node, namer): - transformer = ForLoopCanonicalizationTransformer(namer) - node = transformer.visit(node) - return node +def transform(node, context): + return ForLoopCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py index a6e6350fd45e9c9575af9c12d3d0c4e9b89bee41..142bd4aea12fa26468372472cb8a08e1f6b0e8ac 100644 --- a/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py @@ -41,8 +41,8 @@ class ControlFlowTest(converter_test_base.TestCase): s += e return s - node = self.parse_and_analyze(test_fn, {}) - node = for_canonicalization.transform(node, TestNamer()) + node = self.parse_and_analyze(test_fn, {}, namer=TestNamer()) + node = for_canonicalization.transform(node, self.ctx) result = compiler.ast_to_object(node) l = [1, 2, 3] diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/py2tf/converters/side_effect_guards.py index 4df723989d4710c5bf1aa5568321b17ed98bbd42..7ece8135d674f1af60ced16d6c9976681bb62376 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py +++ b/tensorflow/contrib/py2tf/converters/side_effect_guards.py @@ -40,6 +40,8 @@ import gast from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -57,11 +59,11 @@ class SymbolNamer(object): raise NotImplementedError() -class SideEffectGuardTransformer(gast.NodeTransformer): +class SideEffectGuardTransformer(transformer.Base): """Adds control dependencies to functions with side effects.""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(SideEffectGuardTransformer, self).__init__(context) self.indent_next = False self.next_indent_owner = None @@ -90,12 +92,11 @@ class SideEffectGuardTransformer(gast.NodeTransformer): return new_nodes def visit_FunctionDef(self, node): - if anno.hasanno(node, 'skip_processing'): - return node node.body = self._visit_and_reindent(node.body) return node def _gate_symbols(self, guard_statement, guarded_args): + # TODO(mdan): This won't work for variables. template = """ (args,) = (tf.identity(a) for a in (args,)) """ @@ -110,33 +111,22 @@ class SideEffectGuardTransformer(gast.NodeTransformer): # opt.minimize(loss) # or: # tf.py_func(...) - args_scope = anno.getanno(node.value, 'args_scope') - temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced) - # TODO(mdan): Unsafe reference modification! - args_scope.mark_write(temp_name) template = """ - temp_result = call - if temp_result is not None: - if not isinstance(temp_result, (list, tuple)): - temp_result = (temp_result,) - ctx = tf.control_dependencies(temp_result) - else: - ctx = contextmanager(lambda: (yield))() - with ctx: - # TODO(mdan): Also insert ops to re-fetch if variables are involved. + with py2tf_utils.control_dependency_on_returns(tf, call): + # TODO(mdan): Also insert ops to re-fetch if variables are involved? pass # Will be removed below. """ # TODO(mdan): This is brittle. Reorganize the mechanism. - statements = templates.replace( - template, call=node.value, temp_result=temp_name) + statements = templates.replace(template, call=node.value) control_deps_guard = statements[-1] control_deps_guard.body = [] # First, attempt to gate future evaluation of args. If that's not # possible, gate all remaining statements (and that may fail too, see # _visit_and_reindent. - guarded_args = tuple( - n for n in args_scope.used if n in args_scope.parent.modified) + args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE) + guarded_args = tuple(args_scope.used & (args_scope.parent.modified + | args_scope.parent.returned)) if guarded_args: node = tuple(statements[:-1]) + ( self._gate_symbols(control_deps_guard, guarded_args),) @@ -150,6 +140,5 @@ class SideEffectGuardTransformer(gast.NodeTransformer): # pylint:enable=invalid-name -def transform(node, namer): - transformer = SideEffectGuardTransformer(namer) - return transformer.visit(node) +def transform(node, context): + return SideEffectGuardTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py index 5c56973dc2ae5d1976a68f040772e856cdaeabf5..dea09ecc3ff8c566e7bd6c440bc1a146aaf15121 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.py2tf import utils from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.converters import side_effect_guards from tensorflow.contrib.py2tf.pyct import compiler @@ -42,10 +43,12 @@ class SideEffectGuardsTest(converter_test_base.TestCase): state_ops.assign(a, a + 1) return a - node = self.parse_and_analyze(test_fn, {'state_ops': state_ops}) - node = side_effect_guards.transform(node, TestNamer()) + node = self.parse_and_analyze( + test_fn, {'state_ops': state_ops}, namer=TestNamer()) + node = side_effect_guards.transform(node, self.ctx) result = compiler.ast_to_object(node) setattr(result, 'state_ops', state_ops) + setattr(result, 'py2tf_utils', utils) # TODO(mdan): Configure the namespaces instead of doing these hacks. ops.identity = array_ops.identity diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/py2tf/impl/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..22f0c25cabcd44261c0b42091b50969500db2193 --- /dev/null +++ b/tensorflow/contrib/py2tf/impl/BUILD @@ -0,0 +1,65 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "impl", + srcs = [ + "api.py", + "config.py", + "conversion.py", + "naming.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/py2tf/converters", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/py2tf/pyct/static_analysis", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) + +py_test( + name = "api_test", + srcs = ["api_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "conversion_test", + srcs = ["conversion_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/py2tf/api.py b/tensorflow/contrib/py2tf/impl/api.py similarity index 97% rename from tensorflow/contrib/py2tf/api.py rename to tensorflow/contrib/py2tf/impl/api.py index ca1f4e2645ee20fd78c0d837885823d2e199537a..85d40f31580d156bf719e059bb3580a068595cb5 100644 --- a/tensorflow/contrib/py2tf/api.py +++ b/tensorflow/contrib/py2tf/impl/api.py @@ -23,8 +23,8 @@ from functools import wraps import gast import six -from tensorflow.contrib.py2tf import config -from tensorflow.contrib.py2tf import conversion +from tensorflow.contrib.py2tf.impl import config +from tensorflow.contrib.py2tf.impl import conversion from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.contrib.py2tf.pyct import parser from tensorflow.python.util import tf_inspect @@ -86,8 +86,8 @@ def convert_inline(f, *args, **kwargs): def convert(recursive=False, arg_types=None): """Decorator that compiles a function to graph mode. - The decorator is dynamic - invoking compilation whenever the decorated function - is called. This means the parameter values are known at compilation. + The decorator is dynamic - invoking compilation whenever the decorated + function is called. This means the parameter values are known at compilation. Args: recursive: Whether to recusrively convert any functions that the decorator diff --git a/tensorflow/contrib/py2tf/api_test.py b/tensorflow/contrib/py2tf/impl/api_test.py similarity index 98% rename from tensorflow/contrib/py2tf/api_test.py rename to tensorflow/contrib/py2tf/impl/api_test.py index 2384447708d7e0ab5dbfbeb592a47353f1909f50..dbd079a3ca6d09824f24c6f0bd7647758d3a5552 100644 --- a/tensorflow/contrib/py2tf/api_test.py +++ b/tensorflow/contrib/py2tf/impl/api_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import api -from tensorflow.contrib.py2tf import config +from tensorflow.contrib.py2tf.impl import api +from tensorflow.contrib.py2tf.impl import config from tensorflow.contrib.py2tf.pyct import parser from tensorflow.python.framework import constant_op from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/py2tf/config.py b/tensorflow/contrib/py2tf/impl/config.py similarity index 85% rename from tensorflow/contrib/py2tf/config.py rename to tensorflow/contrib/py2tf/impl/config.py index 8c502a7a9e546dd9b9b40d7cf6d3c9821038afb3..6525806a0933dd9f0a237e278bb70b88346bea27 100644 --- a/tensorflow/contrib/py2tf/config.py +++ b/tensorflow/contrib/py2tf/impl/config.py @@ -32,7 +32,9 @@ DEFAULT_UNCOMPILED_MODULES = set(( NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) # TODO(mdan): Also allow controlling the generated names (for testability). +# TODO(mdan): Verify that these names are not hidden by generated code. +# TODO(mdan): Make sure copybara renames the reference below. COMPILED_IMPORT_STATEMENTS = ( - 'from contextlib import contextmanager', 'import tensorflow as tf', -) + 'from tensorflow.contrib.py2tf import utils as ' + 'py2tf_utils') diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py similarity index 83% rename from tensorflow/contrib/py2tf/conversion.py rename to tensorflow/contrib/py2tf/impl/conversion.py index b484eebbd58b955d1e783359269d16101d83cfd2..ff4f159975578dada45542df39f7ebbb61dd2e36 100644 --- a/tensorflow/contrib/py2tf/conversion.py +++ b/tensorflow/contrib/py2tf/impl/conversion.py @@ -21,8 +21,7 @@ from __future__ import print_function import gast import six -from tensorflow.contrib.py2tf import config -from tensorflow.contrib.py2tf import naming +from tensorflow.contrib.py2tf.converters import asserts from tensorflow.contrib.py2tf.converters import break_canonicalization from tensorflow.contrib.py2tf.converters import builtin_functions from tensorflow.contrib.py2tf.converters import call_trees @@ -31,11 +30,13 @@ from tensorflow.contrib.py2tf.converters import control_flow from tensorflow.contrib.py2tf.converters import decorators from tensorflow.contrib.py2tf.converters import for_canonicalization from tensorflow.contrib.py2tf.converters import logical_expressions -from tensorflow.contrib.py2tf.converters import print_functions from tensorflow.contrib.py2tf.converters import side_effect_guards +from tensorflow.contrib.py2tf.impl import config +from tensorflow.contrib.py2tf.impl import naming from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.util import tf_inspect @@ -171,7 +172,8 @@ def class_to_graph(c, conversion_map): def function_to_graph(f, conversion_map, arg_values, arg_types, owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" - node = parser.parse_object(f).body[0] + node, source = parser.parse_entity(f) + node = node.body[0] namespace = six.get_function_globals(f) # This is needed for non-global functions. @@ -185,28 +187,30 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, namer = conversion_map.new_namer(namespace) ctx = context.EntityContext( namer=namer, - source_code=tf_inspect.getsource(f), - source_file=tf_inspect.getfile(f), + source_code=source, + source_file='', namespace=namespace, arg_values=arg_values, - arg_types=arg_types) + arg_types=arg_types, + recursive=conversion_map.recursive) node = node_to_graph(node, ctx, conversion_map.nocompile_decorators) - # Simulate a rename to ensure the top level is in the name map. This is needed - # for top level functions, and it also helps the consistency verification made - # by update_name_map. - if owner_type is not None: - new_name = namer.compiled_function_name(f.__name__, f, owner_type) - else: - new_name = namer.compiled_function_name(f.__name__, f) + # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) + if not did_rename: + new_name = f.__name__ + if node.name != f.__name__: + raise NotImplementedError('Strange corner case. Send us offending code!') + node.name = new_name conversion_map.update_name_map(namer) - return node, conversion_map.name_map[f] + return node, new_name def _static_analysis_pass(node, ctx): - node = access.resolve(node) - node = live_values.resolve(node, ctx.namespace, config.PYTHON_LITERALS) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx, None) + node = live_values.resolve(node, ctx, config.PYTHON_LITERALS) node = type_info.resolve(node, ctx) return node @@ -230,10 +234,7 @@ def node_to_graph(node, ctx, nocompile_decorators): # TODO(mdan): Factor out common elements. # These include: - # * keeping track of symbols that have been created - # * marking nodes (e.g. py_func wrappers) to suppress further processing # * code move between blocks - # * insertion of new global references # * visiting blocks in transformers # Certain steps, especially canonicalization, insert new symbols into the @@ -241,29 +242,35 @@ def node_to_graph(node, ctx, nocompile_decorators): # to re-run the analysis. node = _static_analysis_pass(node, ctx) + # Past this point, line numbers are no longer accurate so we ignore the + # source. + # TODO(mdan): Is it feasible to reconstruct intermediate source code? + ctx.source_code = None node = decorators.transform(node, nocompile_decorators) - node = break_canonicalization.transform(node, ctx.namer) + node = break_canonicalization.transform(node, ctx) + node = asserts.transform(node, ctx) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_canonicalization.transform(node, ctx.namer) + node = continue_canonicalization.transform(node, ctx) ctx.namespace['len'] = len node = _static_analysis_pass(node, ctx) - node = for_canonicalization.transform(node, ctx.namer) + node = for_canonicalization.transform(node, ctx) # for_canonicalization may insert new global references. - node = builtin_functions.transform(node) + node = builtin_functions.transform(node, ctx) # builtin_functions may insert new global references. ctx.namespace['print'] = print node = _static_analysis_pass(node, ctx) - node = print_functions.transform(node) - node = call_trees.transform(node, ctx.namer, ctx.namespace, - config.DEFAULT_UNCOMPILED_MODULES, + node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, nocompile_decorators) - node = control_flow.transform(node, ctx.namer) + node = control_flow.transform(node, ctx) + + # control_flow may create new symbols and change scopes. + node = _static_analysis_pass(node, ctx) node = logical_expressions.transform(node) - node = side_effect_guards.transform(node, ctx.namer) + node = side_effect_guards.transform(node, ctx) return node diff --git a/tensorflow/contrib/py2tf/conversion_test.py b/tensorflow/contrib/py2tf/impl/conversion_test.py similarity index 97% rename from tensorflow/contrib/py2tf/conversion_test.py rename to tensorflow/contrib/py2tf/impl/conversion_test.py index 26f915f4f46e54c9648ae6b35415c4e2639af774..3888958f19b9fa13b759924c5188722e500e30a1 100644 --- a/tensorflow/contrib/py2tf/conversion_test.py +++ b/tensorflow/contrib/py2tf/impl/conversion_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf import conversion +from tensorflow.contrib.py2tf.impl import conversion from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/impl/naming.py similarity index 57% rename from tensorflow/contrib/py2tf/naming.py rename to tensorflow/contrib/py2tf/impl/naming.py index a90758962b83e1616f7d727440eb7481c49343ad..d31462cba060bf6c04eefbad3ce7f166db994ab3 100644 --- a/tensorflow/contrib/py2tf/naming.py +++ b/tensorflow/contrib/py2tf/impl/naming.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.util import tf_inspect +from tensorflow.contrib.py2tf.pyct import qual_names class Namer(object): @@ -45,10 +45,15 @@ class Namer(object): self.generated_names = set() - def compiled_class_name(self, original_name, live_object=None): + def compiled_class_name(self, original_fqn, live_entity=None): """See call_trees.FunctionNamer.compiled_class_name.""" - if live_object is not None and live_object in self.renamed_calls: - return self.renamed_calls[live_object] + if live_entity is not None and live_entity in self.renamed_calls: + return self.renamed_calls[live_entity] + + if isinstance(original_fqn, tuple): + original_name = '__'.join(original_fqn) + else: + original_name = original_fqn new_name_root = 'Tf%s' % original_name new_name = new_name_root @@ -57,49 +62,63 @@ class Namer(object): n += 1 new_name = '%s_%d' % (new_name_root, n) - if live_object is not None: - self.renamed_calls[live_object] = new_name + if live_entity is not None: + self.renamed_calls[live_entity] = new_name self.generated_names.add(new_name) + if live_entity is not None: + self.renamed_calls[live_entity] = new_name return new_name def compiled_function_name(self, - original_name, - live_object=None, + original_fqn, + live_entity=None, owner_type=None): """See call_trees.FunctionNamer.compiled_function_name.""" - if live_object is not None and live_object in self.renamed_calls: - return self.renamed_calls[live_object] if not self.recursive: - new_name = original_name - elif owner_type is None or owner_type in self.partial_types: - # Top level functions: rename - new_name_root = 'tf__%s' % original_name - new_name = new_name_root - n = 0 - while new_name in self.global_namespace: - n += 1 - new_name = '%s_%d' % (new_name_root, n) + return None, False + + if owner_type is not None and owner_type not in self.partial_types: + # Members are not renamed when part of an entire converted class. + return None, False + + if isinstance(original_fqn, tuple): + original_name = '__'.join(original_fqn) else: - if tf_inspect.isclass(owner_type): - # Class members: do not rename (the entire class will be renamed) - new_name = original_name - else: - raise NotImplementedError('Member function "%s" of non-class type: %s' % - (original_name, owner_type)) + original_name = original_fqn + + if live_entity is not None and live_entity in self.renamed_calls: + return self.renamed_calls[live_entity], True + + new_name_root = 'tf__%s' % original_name + new_name = new_name_root + n = 0 + while new_name in self.global_namespace: + n += 1 + new_name = '%s_%d' % (new_name_root, n) - if live_object is not None: - self.renamed_calls[live_object] = new_name + if live_entity is not None: + self.renamed_calls[live_entity] = new_name self.generated_names.add(new_name) - return new_name + + return new_name, True def new_symbol(self, name_root, reserved_locals): """See control_flow.SymbolNamer.new_symbol.""" + # reserved_locals may contain QNs. + all_reserved_locals = set() + for s in reserved_locals: + if isinstance(s, qual_names.QN): + all_reserved_locals.update(s.qn) + elif isinstance(s, str): + all_reserved_locals.add(s) + else: + raise ValueError('Unexpected symbol type "%s"' % type(s)) + new_name = name_root n = 0 - while (new_name in self.global_namespace - or new_name in reserved_locals - or new_name in self.generated_names): + while (new_name in self.global_namespace or + new_name in all_reserved_locals or new_name in self.generated_names): n += 1 new_name = '%s_%d' % (name_root, n) diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/impl/naming_test.py similarity index 82% rename from tensorflow/contrib/py2tf/naming_test.py rename to tensorflow/contrib/py2tf/impl/naming_test.py index 7bfc9b8733b6efc3ab440ae5a0614258ae395ad4..beb4e54937bbb91b19157c9b9e3c528353206c62 100644 --- a/tensorflow/contrib/py2tf/naming_test.py +++ b/tensorflow/contrib/py2tf/impl/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import naming +from tensorflow.contrib.py2tf.impl import naming from tensorflow.python.platform import test @@ -29,8 +29,9 @@ class NamerTest(test.TestCase): pass namer = naming.Namer({}, True, None, ()) - self.assertEqual('tf__foo', namer.compiled_function_name('foo')) - self.assertEqual('tf__bar', namer.compiled_function_name('bar', bar)) + self.assertEqual(('tf__foo', True), namer.compiled_function_name('foo')) + self.assertEqual(('tf__bar', True), namer.compiled_function_name( + 'bar', bar)) self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls) self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names) @@ -39,15 +40,18 @@ class NamerTest(test.TestCase): pass namer = naming.Namer({}, True, None, ()) - self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) - self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) + self.assertEqual(('tf__foo', True), namer.compiled_function_name( + 'foo', foo)) + self.assertEqual(('tf__foo', True), namer.compiled_function_name( + 'foo', foo)) def test_compiled_function_name_avoids_global_conflicts(self): def foo(): pass namer = naming.Namer({'tf__foo': 1}, True, None, ()) - self.assertEqual('tf__foo_1', namer.compiled_function_name('foo', foo)) + self.assertEqual(('tf__foo_1', True), + namer.compiled_function_name('foo', foo)) def test_new_symbol_tracks_names(self): namer = naming.Namer({}, True, None, ()) diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD index 88902dea84a9da62d8dd9093c181dc17e59672a7..054eb17fb6a0eba38c58c46a657c3ad16b4773dc 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/py2tf/pyct/BUILD @@ -21,8 +21,10 @@ py_library( "anno.py", "compiler.py", "context.py", + "copier.py", "parser.py", "pretty_printer.py", + "qual_names.py", "templates.py", "transformer.py", ], @@ -31,6 +33,7 @@ py_library( deps = [ "@astor_archive//:astor", "@gast_archive//:gast", + "@six_archive//:six", "@termcolor_archive//:termcolor", ], ) @@ -56,6 +59,17 @@ py_test( ], ) +py_test( + name = "copier_test", + srcs = ["copier_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "parser_test", srcs = ["parser_test.py"], diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/py2tf/pyct/anno.py index 889e4ba4ffaed887faffb8736e4a59502da99e81..c6d41f9e128a31b4c3d513615da8c6d0fe51c29d 100644 --- a/tensorflow/contrib/py2tf/pyct/anno.py +++ b/tensorflow/contrib/py2tf/pyct/anno.py @@ -21,6 +21,25 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from enum import Enum + + +class NoValue(Enum): + + def __repr__(self): + return self.name + + +class Basic(NoValue): + """Container for annotation keys. + + The enum values are used strictly for documentation purposes. + """ + + QN = 'Qualified name, as it appeared in the code.' + SKIP_PROCESSING = ( + 'This node should be preserved as is and not processed any further.') + def getanno(node, key, field_name='___pyct_anno'): return getattr(node, field_name)[key] diff --git a/tensorflow/contrib/py2tf/pyct/compiler.py b/tensorflow/contrib/py2tf/pyct/compiler.py index b09353cc72bd5f9d02a8973ebe880b92d39ac304..fc71469d1eaeb92352e3b50cb743621d7e5eb1d5 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler.py +++ b/tensorflow/contrib/py2tf/pyct/compiler.py @@ -41,7 +41,7 @@ def ast_to_source(node, indentation): return astor.source_repr.pretty_source(generator.result).lstrip() -def ast_to_object(node, indentation=' '): +def ast_to_object(node, indentation=' ', source_prefix=None): """Return the Python objects represented by given AST. Compiling the AST code this way ensures that the source code is readable by @@ -50,6 +50,7 @@ def ast_to_object(node, indentation=' '): Args: node: The code to compile, as an AST object. indentation: The string to use for indentation. + source_prefix: Optional string to print as-is into the source file. Returns: A module object containing the compiled source code. @@ -58,5 +59,8 @@ def ast_to_object(node, indentation=' '): with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: module_name = os.path.basename(f.name[:-3]) + if source_prefix: + f.write(source_prefix) + f.write('\n') f.write(source) return imp.load_source(module_name, f.name) diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/py2tf/pyct/context.py index 73f3613d09d01e9e643cfb8ee3a8e67e5c126455..fef74ebefa290369c7310af6d7e4faeef44d9aee 100644 --- a/tensorflow/contrib/py2tf/pyct/context.py +++ b/tensorflow/contrib/py2tf/pyct/context.py @@ -33,10 +33,11 @@ class EntityContext(object): """ def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types): + arg_types, recursive): self.namer = namer self.source_code = source_code self.source_file = source_file self.namespace = namespace self.arg_values = {} if arg_values is None else arg_values self.arg_types = {} if arg_types is None else arg_types + self.recursive = recursive diff --git a/tensorflow/contrib/py2tf/pyct/copier.py b/tensorflow/contrib/py2tf/pyct/copier.py new file mode 100644 index 0000000000000000000000000000000000000000..41598fdc995a187fe5ff8ed7f86dad0a96d62fe5 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/copier.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Copy an AST tree, discarding annotations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast + +import gast + +from tensorflow.contrib.py2tf.pyct import anno + + +class CleanCopier(gast.NodeVisitor): + """Copy AST nodes. + + The copied nodes will ignore almost all fields that prefixed by '__'. + Exceptions make some annotations. + """ + + # TODO(mdan): Parametrize which annotations get carried over. + + def generic_visit(self, node): + new_fields = {} + for f in node._fields: + if f.startswith('__'): + continue + if not hasattr(node, f): + continue + v = getattr(node, f) + if isinstance(v, list): + v = [self.generic_visit(n) for n in v] + elif isinstance(v, tuple): + v = tuple(self.generic_visit(n) for n in v) + elif isinstance(v, (gast.AST, ast.AST)): + v = self.generic_visit(v) + else: + # Assume everything else is a value type. + pass + new_fields[f] = v + new_node = type(node)(**new_fields) + if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True) + return new_node + + +def copy_clean(node): + copier = CleanCopier() + if isinstance(node, list): + return [copier.visit(n) for n in node] + elif isinstance(node, tuple): + return tuple(copier.visit(n) for n in node) + else: + return copier.visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/copier_test.py b/tensorflow/contrib/py2tf/pyct/copier_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b35eda1492b7981557a5a515abdd36b05c8c87 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/copier_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for copier module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast + +from tensorflow.contrib.py2tf.pyct import copier +from tensorflow.python.platform import test + + +class CopierTest(test.TestCase): + + def test_copy_clean(self): + ret = ast.Return( + ast.BinOp( + op=ast.Add(), + left=ast.Name(id='a', ctx=ast.Load()), + right=ast.Num(1))) + setattr(ret, '__foo', 'bar') + node = ast.FunctionDef( + name='f', + args=ast.arguments( + args=[ast.Name(id='a', ctx=ast.Param())], + vararg=None, + kwarg=None, + defaults=[]), + body=[ret], + decorator_list=[], + returns=None) + new_node = copier.copy_clean(node) + self.assertFalse(node is new_node) + self.assertFalse(ret is new_node.body[0]) + self.assertFalse(hasattr(new_node.body[0], '__foo')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/parser.py b/tensorflow/contrib/py2tf/pyct/parser.py index 3daa69b9ceff714c94c61134f6fb81f9927ea258..dc7df883b349becd860bb0dbceab22cb39c750b5 100644 --- a/tensorflow/contrib/py2tf/pyct/parser.py +++ b/tensorflow/contrib/py2tf/pyct/parser.py @@ -28,11 +28,13 @@ import gast from tensorflow.python.util import tf_inspect -def parse_object(obj): - """Return the AST of given object.""" - return parse_str(tf_inspect.getsource(obj)) +def parse_entity(entity): + """Return the AST of given entity.""" + source = tf_inspect.getsource(entity) + source = textwrap.dedent(source) + return parse_str(source), source def parse_str(src): """Return the AST of given piece of code.""" - return gast.parse(textwrap.dedent(src)) + return gast.parse(src) diff --git a/tensorflow/contrib/py2tf/pyct/parser_test.py b/tensorflow/contrib/py2tf/pyct/parser_test.py index 46f9aa82071efa98518810851b76761ff42751e5..f35dfa04c70dc191078248c32f9a04d28133129a 100644 --- a/tensorflow/contrib/py2tf/pyct/parser_test.py +++ b/tensorflow/contrib/py2tf/pyct/parser_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import textwrap + from tensorflow.contrib.py2tf.pyct import parser from tensorflow.python.platform import test @@ -28,15 +30,16 @@ def f(x): class ParserTest(test.TestCase): - def test_parse_object(self): - mod = parser.parse_object(f) + def test_parse_entity(self): + mod, _ = parser.parse_entity(f) self.assertEqual('f', mod.body[0].name) def test_parse_str(self): - mod = parser.parse_str(""" + mod = parser.parse_str( + textwrap.dedent(""" def f(x): return x + 1 - """) + """)) self.assertEqual('f', mod.body[0].name) diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer.py b/tensorflow/contrib/py2tf/pyct/pretty_printer.py index 5e70c0ed833c10012e6a5b4cb26e9e4198162693..bacc1e4a7774ec5b84495255042392fe089150d5 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer.py +++ b/tensorflow/contrib/py2tf/pyct/pretty_printer.py @@ -25,24 +25,30 @@ import termcolor class PrettyPrinter(gast.NodeVisitor): """Print AST nodes.""" - def __init__(self): + def __init__(self, color): self.indent_lvl = 0 self.result = '' + self.color = color + + def _color(self, string, color, attrs=None): + if self.color: + return termcolor.colored(string, color, attrs=attrs) + return string def _type(self, node): - return termcolor.colored(node.__class__.__name__, None, attrs=['bold']) + return self._color(node.__class__.__name__, None, ['bold']) def _field(self, name): - return termcolor.colored(name, 'blue') + return self._color(name, 'blue') def _value(self, name): - return termcolor.colored(name, 'magenta') + return self._color(name, 'magenta') def _warning(self, name): - return termcolor.colored(name, 'red') + return self._color(name, 'red') def _indent(self): - return termcolor.colored('| ' * self.indent_lvl, None, attrs=['dark']) + return self._color('| ' * self.indent_lvl, None, ['dark']) def _print(self, s): self.result += s @@ -76,6 +82,16 @@ class PrettyPrinter(gast.NodeVisitor): self._print('%s]' % (self._indent())) else: self._print('%s%s=[]' % (self._indent(), self._field(f))) + elif isinstance(v, tuple): + if v: + self._print('%s%s=(' % (self._indent(), self._field(f))) + self.indent_lvl += 1 + for n in v: + self.generic_visit(n) + self.indent_lvl -= 1 + self._print('%s)' % (self._indent())) + else: + self._print('%s%s=()' % (self._indent(), self._field(f))) elif isinstance(v, gast.AST): self.generic_visit(v, f) elif isinstance(v, str): @@ -87,8 +103,8 @@ class PrettyPrinter(gast.NodeVisitor): self.indent_lvl -= 1 -def fmt(node): - printer = PrettyPrinter() +def fmt(node, color=True): + printer = PrettyPrinter(color) if isinstance(node, (list, tuple)): for n in node: printer.visit(n) diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py b/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py index 65e5b1d9191749a0caeeda48df37690564a8fc1e..81e3f47b80b6cb3bb7ba9f4a1787d03df4151a99 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py +++ b/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py @@ -24,10 +24,6 @@ from tensorflow.contrib.py2tf.pyct import pretty_printer from tensorflow.python.platform import test -def f(x): - return x + 1 - - class PrettyPrinterTest(test.TestCase): def test_format(self): diff --git a/tensorflow/contrib/py2tf/pyct/qual_names.py b/tensorflow/contrib/py2tf/pyct/qual_names.py new file mode 100644 index 0000000000000000000000000000000000000000..11e3838467783807530576413bec20c9904f873f --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/qual_names.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for manipulating qualified names. + +A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite +(e.g. 'foo.bar') syntactic symbols. + +This is *not* related to the __qualname__ attribute used by inspect, which +refers to scopes. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import anno + + +class QN(object): + """Represents a qualified name. + + """ + + def __init__(self, base, attr=None): + if attr: + if not isinstance(base, QN): + raise ValueError('For attribute QNs, base must be a QN.') + self._parent = base + self.qn = base.qn + (attr,) + else: + self._parent = None + self.qn = tuple(base.split('.')) + + def is_composite(self): + return len(self.qn) > 1 + + @property + def parent(self): + if self._parent is None: + raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) + return self._parent + + def __hash__(self): + return hash(self.qn) + + def __eq__(self, other): + return self.qn == other.qn + + def __str__(self): + return '.'.join(self.qn) + + def __repr__(self): + return str(self) + + def ssf(self): + """Simple symbol form.""" + return '_'.join(self.qn) + + def ast(self): + # The caller must adjust the context appropriately. + if self.is_composite(): + return gast.Attribute(self.parent.ast(), self.qn[-1], None) + return gast.Name(self.qn[0], None, None) + + +class QnResolver(gast.NodeTransformer): + """Annotates nodes with QN information. + + Note: Not using NodeAnnos to avoid circular dependencies. + """ + + def visit_Name(self, node): + self.generic_visit(node) + anno.setanno(node, anno.Basic.QN, QN(node.id)) + return node + + def visit_Attribute(self, node): + self.generic_visit(node) + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), node.attr)) + return node + + +def resolve(node): + return QnResolver().visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD b/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD index 32e2954fffca3b9f512116648117904b85a60e25..fbfce18c60cca4b105e7de3c3ea7b9c3438f6b2a 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD @@ -17,7 +17,8 @@ filegroup( py_library( name = "static_analysis", srcs = [ - "access.py", + "activity.py", + "annos.py", "live_values.py", "type_info.py", ], @@ -30,8 +31,8 @@ py_library( ) py_test( - name = "access_test", - srcs = ["access_test.py"], + name = "activity_test", + srcs = ["activity_test.py"], srcs_version = "PY2AND3", deps = [ ":static_analysis", diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py deleted file mode 100644 index 0912ebb4c355c2ae2563e13e36926a4b8e3599a1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for access module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access -from tensorflow.python.platform import test - - -class ScopeTest(test.TestCase): - - def test_basic(self): - scope = access.Scope(None) - self.assertFalse(scope.has('foo')) - - scope.mark_read('foo') - self.assertFalse(scope.has('foo')) - - scope.mark_write('foo') - self.assertTrue(scope.has('foo')) - - scope.mark_read('bar') - self.assertFalse(scope.has('bar')) - - def test_copy(self): - scope = access.Scope(None) - scope.mark_write('foo') - - other = access.Scope(None) - other.copy_from(scope) - - self.assertTrue('foo' in other.created) - - scope.mark_write('bar') - scope.copy_from(other) - - self.assertFalse('bar' in scope.created) - - scope.mark_write('bar') - scope.merge_from(other) - - self.assertTrue('bar' in scope.created) - self.assertFalse('bar' in other.created) - - def test_nesting(self): - scope = access.Scope(None) - scope.mark_write('foo') - scope.mark_read('bar') - - child = access.Scope(scope) - self.assertTrue(child.has('foo')) - self.assertTrue(scope.has('foo')) - - child.mark_write('bar') - self.assertTrue(child.has('bar')) - self.assertFalse(scope.has('bar')) - - def test_referenced(self): - scope = access.Scope(None) - scope.mark_read('a') - - child = access.Scope(scope) - child.mark_read('b') - - child2 = access.Scope(child, isolated=False) - child2.mark_read('c') - - self.assertTrue('c' in child2.referenced) - self.assertTrue('b' in child2.referenced) - self.assertFalse('a' in child2.referenced) - - self.assertTrue('c' in child.referenced) - self.assertTrue('b' in child.referenced) - self.assertFalse('a' in child.referenced) - - -class AccessResolverTest(test.TestCase): - - def test_local_markers(self): - - def test_fn(a): # pylint:disable=unused-argument - b = c # pylint:disable=undefined-variable - while b > 0: - b -= 1 - return b - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - self.assertFalse(anno.getanno(node.body[0].body[0].value, - 'is_local')) # c in b = c - self.assertTrue(anno.getanno(node.body[0].body[1].test.left, - 'is_local')) # b in b > 0 - self.assertTrue(anno.getanno(node.body[0].body[2].value, - 'is_local')) # b in return b - - def assertScopeIs(self, scope, used, modified, created): - self.assertItemsEqual(used, scope.used) - self.assertItemsEqual(modified, scope.modified) - self.assertItemsEqual(created, scope.created) - - def test_print_statement(self): - - def test_fn(a): - b = 0 - c = 1 - print(a, b) - return c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - print_node = node.body[0].body[2] - if isinstance(print_node, gast.Print): - # Python 2 - print_args_scope = anno.getanno(print_node, 'args_scope') - else: - # Python 3 - assert isinstance(print_node, gast.Expr) - # The call node should be the one being annotated. - print_node = print_node.value - print_args_scope = anno.getanno(print_node, 'args_scope') - # We basically need to detect which variables are captured by the call - # arguments. - self.assertScopeIs(print_args_scope, ('a', 'b'), (), ()) - - def test_call(self): - - def test_fn(a): - b = 0 - c = 1 - foo(a, b) # pylint:disable=undefined-variable - return c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - call_node = node.body[0].body[2].value - # We basically need to detect which variables are captured by the call - # arguments. - self.assertScopeIs( - anno.getanno(call_node, 'args_scope'), ('a', 'b'), (), ()) - - def test_while(self): - - def test_fn(a): - b = a - while b > 0: - c = b - b -= 1 - return b, c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - while_node = node.body[0].body[1] - self.assertScopeIs( - anno.getanno(while_node, 'body_scope'), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( - anno.getanno(while_node, 'body_parent_scope'), ('a', 'b', 'c'), - ('a', 'b', 'c'), ('a', 'b', 'c')) - - def test_for(self): - - def test_fn(a): - b = a - for _ in a: - c = b - b -= 1 - return b, c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - for_node = node.body[0].body[1] - self.assertScopeIs( - anno.getanno(for_node, 'body_scope'), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( - anno.getanno(for_node, 'body_parent_scope'), ('a', 'b', 'c'), - ('a', 'b', 'c', '_'), ('a', 'b', 'c', '_')) - - def test_if(self): - - def test_fn(x): - if x > 0: - x = -x - y = 2 * x - z = -y - else: - x = 2 * x - y = -x - u = -y - return z, u - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - if_node = node.body[0].body[0] - self.assertScopeIs( - anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'), - ('y', 'z')) - # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? - self.assertScopeIs( - anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'), - ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - self.assertScopeIs( - anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'), - ('y', 'u')) - self.assertScopeIs( - anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'), - ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py b/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py similarity index 60% rename from tensorflow/contrib/py2tf/pyct/static_analysis/access.py rename to tensorflow/contrib/py2tf/pyct/static_analysis/activity.py index 8f3ac48b68c05256fbac4c4d8d86381755c8027c..1c93e1603113d48176af7a97a0f37321e6f67586 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Access information (reads, writes) resolution.""" +"""Activity analysis.""" from __future__ import absolute_import from __future__ import division @@ -23,6 +23,8 @@ import copy import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). @@ -53,6 +55,8 @@ class Scope(object): self.modified = set() self.created = set() self.used = set() + self.params = set() + self.returned = set() # TODO(mdan): Rename to `locals` @property @@ -69,61 +73,116 @@ class Scope(object): self.modified = copy.copy(other.modified) self.created = copy.copy(other.created) self.used = copy.copy(other.used) + self.params = copy.copy(other.params) + self.returned = copy.copy(other.returned) def merge_from(self, other): self.modified |= other.modified self.created |= other.created self.used |= other.used + self.params |= other.params + self.returned |= other.returned def has(self, name): - if name in self.modified: + if name in self.modified or name in self.params: return True elif self.parent is not None: return self.parent.has(name) return False + def is_modified_since_entry(self, name): + if name in self.modified: + return True + elif self.parent is not None and not self.isolated: + return self.parent.is_modified_since_entry(name) + return False + + def is_param(self, name): + if name in self.params: + return True + elif self.parent is not None and not self.isolated: + return self.parent.is_param(name) + return False + def mark_read(self, name): self.used.add(name) if self.parent is not None and name not in self.created: self.parent.mark_read(name) + def mark_param(self, name): + self.params.add(name) + + def mark_creation(self, name): + if name.is_composite(): + parent = name.parent + if self.has(parent): + # This is considered mutation of the parent, not creation. + # TODO(mdan): Is that really so? + return + else: + raise ValueError('Unknown symbol "%s".' % parent) + self.created.add(name) + def mark_write(self, name): self.modified.add(name) if self.isolated: - self.created.add(name) + self.mark_creation(name) else: if self.parent is None: - self.created.add(name) + self.mark_creation(name) else: if not self.parent.has(name): - self.created.add(name) + self.mark_creation(name) self.parent.mark_write(name) + def mark_returned(self, name): + self.returned.add(name) + if not self.isolated and self.parent is not None: + self.parent.mark_returned(name) + -class AccessResolver(gast.NodeTransformer): +class ActivityAnalizer(transformer.Base): """Annotates nodes with local scope information. See Scope.""" - def __init__(self): - self.scope = Scope(None) + def __init__(self, context, parent_scope): + super(ActivityAnalizer, self).__init__(context) + self.scope = Scope(parent_scope) + self._in_return_statement = False + + def _track_symbol(self, node): + qn = anno.getanno(node, anno.Basic.QN) - def visit_Name(self, node): - # TODO(mdan): This is insufficient for object fields, e.g. hp.learning_rate. - self.generic_visit(node) if isinstance(node.ctx, gast.Store): - self.scope.mark_write(node.id) + self.scope.mark_write(qn) elif isinstance(node.ctx, gast.Load): - anno.setanno(node, 'is_local', self.scope.has(node.id)) - self.scope.mark_read(node.id) + self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. # TODO(mdan): This bay be incorrect with nested functions. # For nested functions, we'll have to add the notion of hiding args from # the parent scope, not writing to them. - self.scope.mark_write(node.id) + self.scope.mark_creation(qn) + self.scope.mark_param(qn) else: - raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), - node.id)) + raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn)) + + anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn)) + anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY, + self.scope.is_modified_since_entry(qn)) + anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn)) + + if self._in_return_statement: + self.scope.mark_returned(qn) + + def visit_Name(self, node): + self.generic_visit(node) + self._track_symbol(node) + return node + + def visit_Attribute(self, node): + self.generic_visit(node) + self._track_symbol(node) return node def visit_Print(self, node): @@ -132,20 +191,20 @@ class AccessResolver(gast.NodeTransformer): self.scope = args_scope for n in node.values: self.visit(n) - anno.setanno(node, 'args_scope', args_scope) + anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) self.scope = current_scope return node def visit_Call(self, node): current_scope = self.scope - args_scope = Scope(current_scope) + args_scope = Scope(current_scope, isolated=False) self.scope = args_scope for n in node.args: self.visit(n) # TODO(mdan): Account starargs, kwargs for n in node.keywords: self.visit(n) - anno.setanno(node, 'args_scope', args_scope) + anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) self.scope = current_scope self.visit(node.func) return node @@ -156,7 +215,7 @@ class AccessResolver(gast.NodeTransformer): self.scope = block_scope for n in block: self.visit(n) - anno.setanno(node, '%s_scope' % scope_name, block_scope) + anno.setanno(node, scope_name, block_scope) self.scope = current_scope return node @@ -168,38 +227,44 @@ class AccessResolver(gast.NodeTransformer): before_parent = Scope(None) before_parent.copy_from(self.scope) after_children = [] - for child, name in children: + for child, scope_name in children: self.scope.copy_from(before_parent) - parent = self._process_block_node(parent, child, name) + parent = self._process_block_node(parent, child, scope_name) after_child = Scope(None) after_child.copy_from(self.scope) after_children.append(after_child) for after_child in after_children: self.scope.merge_from(after_child) - for child, name in children: - # TODO(mdan): We don't need this - we have the parent link from scope. - anno.setanno(parent, '%s_parent_scope' % name, self.scope) return parent def visit_If(self, node): self.visit(node.test) - node = self._process_parallel_blocks( - node, ((node.body, 'body'), (node.orelse, 'orelse'))) + node = self._process_parallel_blocks(node, + ((node.body, NodeAnno.BODY_SCOPE), + (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_For(self, node): self.visit(node.target) self.visit(node.iter) - node = self._process_parallel_blocks( - node, ((node.body, 'body'), (node.orelse, 'orelse'))) + node = self._process_parallel_blocks(node, + ((node.body, NodeAnno.BODY_SCOPE), + (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_While(self, node): self.visit(node.test) - node = self._process_parallel_blocks( - node, ((node.body, 'body'), (node.orelse, 'orelse'))) + node = self._process_parallel_blocks(node, + ((node.body, NodeAnno.BODY_SCOPE), + (node.orelse, NodeAnno.ORELSE_SCOPE))) + return node + + def visit_Return(self, node): + self._in_return_statement = True + node = self.generic_visit(node) + self._in_return_statement = False return node -def resolve(node): - return AccessResolver().visit(node) +def resolve(node, context, parent_scope=None): + return ActivityAnalizer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e1eb954a5efef4d6a00ac492e7c85394d54e28c9 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py @@ -0,0 +1,271 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for activity module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import context +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.qual_names import QN +from tensorflow.contrib.py2tf.pyct.static_analysis import activity +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.python.platform import test + + +class ScopeTest(test.TestCase): + + def test_basic(self): + scope = activity.Scope(None) + self.assertFalse(scope.has(QN('foo'))) + + scope.mark_read(QN('foo')) + self.assertFalse(scope.has(QN('foo'))) + + scope.mark_write(QN('foo')) + self.assertTrue(scope.has(QN('foo'))) + + scope.mark_read(QN('bar')) + self.assertFalse(scope.has(QN('bar'))) + + def test_copy(self): + scope = activity.Scope(None) + scope.mark_write(QN('foo')) + + other = activity.Scope(None) + other.copy_from(scope) + + self.assertTrue(QN('foo') in other.created) + + scope.mark_write(QN('bar')) + scope.copy_from(other) + + self.assertFalse(QN('bar') in scope.created) + + scope.mark_write(QN('bar')) + scope.merge_from(other) + + self.assertTrue(QN('bar') in scope.created) + self.assertFalse(QN('bar') in other.created) + + def test_nesting(self): + scope = activity.Scope(None) + scope.mark_write(QN('foo')) + scope.mark_read(QN('bar')) + + child = activity.Scope(scope) + self.assertTrue(child.has(QN('foo'))) + self.assertTrue(scope.has(QN('foo'))) + + child.mark_write(QN('bar')) + self.assertTrue(child.has(QN('bar'))) + self.assertFalse(scope.has(QN('bar'))) + + def test_referenced(self): + scope = activity.Scope(None) + scope.mark_read(QN('a')) + + child = activity.Scope(scope) + child.mark_read(QN('b')) + + child2 = activity.Scope(child, isolated=False) + child2.mark_read(QN('c')) + + self.assertTrue(QN('c') in child2.referenced) + self.assertTrue(QN('b') in child2.referenced) + self.assertFalse(QN('a') in child2.referenced) + + self.assertTrue(QN('c') in child.referenced) + self.assertTrue(QN('b') in child.referenced) + self.assertFalse(QN('a') in child.referenced) + + +class ActivityAnalizerTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + recursive=True) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + return node + + def test_local_markers(self): + + def test_fn(a): # pylint:disable=unused-argument + b = c # pylint:disable=undefined-variable + while b > 0: + b -= 1 + return b + + node = self._parse_and_analyze(test_fn) + self.assertFalse( + anno.getanno(node.body[0].body[0].value, + NodeAnno.IS_LOCAL)) # c in b = c + self.assertTrue( + anno.getanno(node.body[0].body[1].test.left, + NodeAnno.IS_LOCAL)) # b in b > 0 + self.assertTrue( + anno.getanno(node.body[0].body[2].value, + NodeAnno.IS_LOCAL)) # b in return b + + def assertScopeIs(self, scope, used, modified, created): + self.assertItemsEqual(used, tuple(str(s) for s in scope.used)) + self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified)) + self.assertItemsEqual(created, tuple(str(s) for s in scope.created)) + + def test_print_statement(self): + + def test_fn(a): + b = 0 + c = 1 + print(a, b) + return c + + node = self._parse_and_analyze(test_fn) + print_node = node.body[0].body[2] + if isinstance(print_node, gast.Print): + # Python 2 + print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) + else: + # Python 3 + assert isinstance(print_node, gast.Expr) + # The call node should be the one being annotated. + print_node = print_node.value + print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) + # We basically need to detect which variables are captured by the call + # arguments. + self.assertScopeIs(print_args_scope, ('a', 'b'), (), ()) + + def test_call(self): + + def test_fn(a): + b = 0 + c = 1 + foo(a, b) # pylint:disable=undefined-variable + return c + + node = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[2].value + # We basically need to detect which variables are captured by the call + # arguments. + self.assertScopeIs( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ()) + + def test_while(self): + + def test_fn(a): + b = a + while b > 0: + c = b + b -= 1 + return b, c + + node = self._parse_and_analyze(test_fn) + while_node = node.body[0].body[1] + self.assertScopeIs( + anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), + ('c',)) + self.assertScopeIs( + anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), + ('b', 'c'), ('a', 'b', 'c')) + + def test_for(self): + + def test_fn(a): + b = a + for _ in a: + c = b + b -= 1 + return b, c + + node = self._parse_and_analyze(test_fn) + for_node = node.body[0].body[1] + self.assertScopeIs( + anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) + self.assertScopeIs( + anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), + ('b', 'c', '_'), ('a', 'b', 'c', '_')) + + def test_if(self): + + def test_fn(x): + if x > 0: + x = -x + y = 2 * x + z = -y + else: + x = 2 * x + y = -x + u = -y + return z, u + + node = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), + ('y', 'z')) + # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'), + ('x', 'y', 'u'), ('y', 'u')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + + def test_call_with_composite_names(self): + + def foo(*_): + pass + + def test_fn(a): + foo(a.b, a.c) + if a > 0: + a.b = 2 + else: + d = 2 + d.e = a.c + f = d.e + 1 + a.c = f + + node = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[0].value + self.assertScopeIs( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (), + ()) + if_node = node.body[0].body[1] + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ()) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), + ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py b/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py new file mode 100644 index 0000000000000000000000000000000000000000..2d8e49442364fdd4a4752c8a83a5f3b76117fe57 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Annotations used by the static analizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from enum import Enum + + +class NoValue(Enum): + + def __repr__(self): + return self.name + + +class NodeAnno(NoValue): + """Additionnal annotations used by the static analyzer. + + These are in addition to the basic annotations declared in anno.py. + """ + + # Symbols + + IS_LOCAL = 'Symbol is local to the function scope being analized.' + IS_PARAM = 'Symbol is a parameter to the function being analized.' + IS_MODIFIED_SINCE_ENTRY = ( + 'Symbol has been explicitly replaced in the current function scope.') + + # Scopes + ARGS_SCOPE = 'The scope for the argument list of a function call.' + BODY_SCOPE = ( + 'The scope for the main body of a statement (True branch for if ' + 'statements, main body for loops).') + ORELSE_SCOPE = ( + 'The scope for the orelse body of a statement (False branch for if ' + 'statements, orelse body for loops).') diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py index 242e544b5286c683ee4aa97bc586751932c73815..9c0a9a9e74eccb3d22840032e8f0c2b81e051e7e 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py @@ -16,7 +16,7 @@ Live values are extracted from the known execution context. -Requires annotations generated by AccessResolver. +Requires activity analysis annotations. """ from __future__ import absolute_import @@ -26,47 +26,56 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class LiveValueResolver(gast.NodeTransformer): +class LiveValueResolver(transformer.Base): """Annotates nodes with live values.""" - def __init__(self, namespace, literals): - """Create a new resolver. - - Args: - namespace: A dict representing the namespace visible to the AST in the - intended execution context. - literals: A dict mapping literal lymbol names to their value. An example - literal is "None". - """ - self.namespace = namespace + def __init__(self, context, literals): + super(LiveValueResolver, self).__init__(context) self.literals = literals def visit_ClassDef(self, node): self.generic_visit(node) - anno.setanno(node, 'live_val', self.namespace[node.name]) + anno.setanno(node, 'live_val', self.context.namespace[node.name]) return node def visit_Name(self, node): self.generic_visit(node) if isinstance(node.ctx, gast.Load): - assert anno.hasanno(node, 'is_local'), node - symbol_is_local = anno.getanno(node, 'is_local') - if not symbol_is_local: + assert anno.hasanno(node, NodeAnno.IS_LOCAL), node + symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL) + assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node + symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY) + assert anno.hasanno(node, NodeAnno.IS_PARAM), node + symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM) + + if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) # TODO(mdan): Could live values have FQNs? i.e. 'a'.join() - elif node.id in self.namespace: - obj = self.namespace[node.id] + elif node.id in self.context.namespace: + obj = self.context.namespace[node.id] anno.setanno(node, 'live_val', obj) anno.setanno(node, 'fqn', (obj.__name__,)) else: - raise ValueError('Could not find global symbol %s.' % node.id) + pass + # TODO(mdan): Should we raise an error here? + # Can encounter this when: + # * a symbol truly lacks reference + # * a symbol is new, like the new name of a function we just renamed. else: pass # TODO(mdan): Attempt to trace its value through the local chain. # TODO(mdan): Use type annotations as fallback. + + if not symbol_is_modified: + if node.id in self.context.arg_values: + obj = self.context.arg_values[node.id] + anno.setanno(node, 'live_val', obj) + anno.setanno(node, 'fqn', (obj.__class__.__name__,)) return node def visit_Attribute(self, node): @@ -79,15 +88,25 @@ class LiveValueResolver(gast.NodeTransformer): node.attr)) anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) + # TODO(mdan): Investigate the role built-in annotations can play here. + elif anno.hasanno(node.value, 'type'): + parent_type = anno.getanno(node.value, 'type') + if hasattr(parent_type, node.attr): + # This should hold for static members like methods. + # This would not hold for dynamic members like function attributes. + # For the dynamic case, we simply leave the node without an annotation, + # and let downstream consumers figure out what to do. + anno.setanno(node, 'live_val', getattr(parent_type, node.attr)) + anno.setanno(node, 'fqn', + anno.getanno(node.value, 'type_fqn') + (node.attr,)) elif isinstance(node.value, gast.Name): stem_name = node.value # All nonlocal symbols should be fully resolved. - assert anno.hasanno(stem_name, 'is_local'), stem_name - assert anno.getanno(stem_name, 'is_local'), stem_name + assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name # TODO(mdan): Figure out what to do when calling attribute on local object # Maybe just leave as-is? return node -def resolve(node, namespace, literals): - return LiveValueResolver(namespace, literals).visit(node) +def resolve(node, context, literals): + return LiveValueResolver(context, literals).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py index e77497654a0b3096422deef9a3f008eeb6c6be05..9f64689401e3594a77fbdd7b6f02880bd6e90492 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py @@ -19,24 +19,47 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values +from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class LiveValuesResolverTest(test.TestCase): + def _parse_and_analyze(self, + test_fn, + namespace, + literals=None, + arg_types=None): + literals = literals or {} + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + recursive=True) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = live_values.resolve(node, ctx, literals) + node = type_info.resolve(node, ctx) + node = live_values.resolve(node, ctx, literals) + return node + def test_literals(self): def test_fn(): return Foo # pylint: disable=undefined-variable - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {}, {'Foo': 'bar'}) - + node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'}) retval_node = node.body[0].body[0].value self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) @@ -48,10 +71,7 @@ class LiveValuesResolverTest(test.TestCase): def test_fn(): return foo() - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'foo': foo}, {}) - + node = self._parse_and_analyze(test_fn, {'foo': foo}) func_node = node.body[0].body[0].value.func self.assertEquals(foo, anno.getanno(func_node, 'live_val')) self.assertEquals(('foo',), anno.getanno(func_node, 'fqn')) @@ -61,15 +81,29 @@ class LiveValuesResolverTest(test.TestCase): def test_fn(): return constant_op.constant(0) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'constant_op': constant_op}, {}) - + node = self._parse_and_analyze(test_fn, {'constant_op': constant_op}) func_node = node.body[0].body[0].value.func self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val')) self.assertEquals((constant_op.__name__, 'constant'), anno.getanno(func_node, 'fqn')) + def test_attributes_with_type_hints(self): + + class TestClass(object): + + def member(self): + pass + + def test_fn(self): + return self.member() + + node = self._parse_and_analyze( + TestClass.test_fn, {'constant_op': constant_op}, + arg_types={'self': (TestClass.__name__, TestClass)}) + func_node = node.body[0].body[0].value.func + self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) + self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py index 0042aa90ed218d42aedc720c94d1a478bc9f18f5..8203bda0f9a792a5b24b9abb25d8f39b61625748 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py @@ -36,8 +36,6 @@ class Scope(object): most recently assigned to the symbol. """ - # TODO(mdan): Should rather use a CFG here? - def __init__(self, parent): """Create a new scope. @@ -117,18 +115,34 @@ class TypeInfoResolver(transformer.Base): node.orelse = self._visit_block(node.orelse) return node + def _process_function_arg(self, arg_name): + str_name = str(arg_name) + if self.function_level == 1 and str_name in self.context.arg_types: + # Forge a node to hold the type information, so that method calls on + # it can resolve the type. + type_holder = arg_name.ast() + type_string, type_obj = self.context.arg_types[str_name] + anno.setanno(type_holder, 'type', type_obj) + anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) + self.scope.setval(arg_name, type_holder) + + def visit_arg(self, node): + self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) + return node + def visit_Name(self, node): self.generic_visit(node) + qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Param): - self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None)) - if self.function_level == 1 and node.id in self.context.arg_types: - # Forge a node to hold the type information, so that method calls on - # it can resolve the type. - type_holder = gast.Name(node.id, gast.Load(), None) - type_string, type_obj = self.context.arg_types[node.id] - anno.setanno(type_holder, 'type', type_obj) - anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) - self.scope.setval(node.id, type_holder) + self._process_function_arg(qn) + elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): + # E.g. if we had + # a = b + # then for future references to `a` we should have traced_source = `b` + traced_source = self.scope.getval(qn) + if anno.hasanno(traced_source, 'type'): + anno.setanno(node, 'type', anno.getanno(traced_source, 'type')) + anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn')) return node def _process_variable_assignment(self, source, targets): @@ -147,16 +161,11 @@ class TypeInfoResolver(transformer.Base): for t in targets: if isinstance(t, gast.Tuple): for i, e in enumerate(t.elts): - self.scope.setval(e.id, - gast.Subscript( - source, gast.Index(i), ctx=gast.Store())) - elif isinstance(t, gast.Name): - self.scope.setval(t.id, source) - elif isinstance(t, gast.Attribute): - if not (isinstance(t.value, gast.Name) and t.value.id == 'self'): - raise ValueError( - 'Dont know how to handle assignment to attributes of objects' - ' other than "self": [%s].%s' % (t.value, t.attr)) + self.scope.setval( + anno.getanno(e, anno.Basic.QN), + gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + elif isinstance(t, (gast.Name, gast.Attribute)): + self.scope.setval(anno.getanno(t, anno.Basic.QN), source) else: raise ValueError('Dont know how to handle assignment to %s' % t) @@ -172,38 +181,6 @@ class TypeInfoResolver(transformer.Base): self._process_variable_assignment(node.value, node.targets) return node - def visit_Call(self, node): - target = node.func - if not anno.hasanno(target, 'live_val'): - if not isinstance(target, gast.Attribute): - # Suspecting this pattern would reach here: - # foo = bar - # foo() - raise ValueError('Dont know how to handle dynamic functions.') - if not isinstance(target.value, gast.Name): - # Possible example of this kind: - # foo = module.Foo() - # foo.bar.baz() - # TODO(mdan): This should be doable by using the FQN. - raise ValueError('Dont know how to handle object properties yet.') - # In the example below, object_source is 'tr.train.Optimizer()': - # opt = tf.train.Optimizer() - # opt.foo() - if self.scope.hasval(target.value.id): - object_source = self.scope.getval(target.value.id) - if not anno.hasanno(object_source, 'type'): - raise ValueError('Could not determine type of "%s". Is it dynamic?' % - (target.value.id)) - anno.setanno(target, 'type', anno.getanno(object_source, 'type')) - anno.setanno(target, 'type_fqn', anno.getanno(object_source, - 'type_fqn')) - else: - # TODO(mdan): Figure out what could the user do to get past this. - raise ValueError('No info on "%s". Is it dynamically built?' % - (target.value.id)) - self.generic_visit(node) - return node - def resolve(node, context): return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py index a491f49ca3b87d1340fdd691431e127737abc006..3659f949db9910534870d8dd9e42fd4ee8297253 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py @@ -21,8 +21,8 @@ from __future__ import print_function from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.client import session @@ -57,17 +57,20 @@ class ScopeTest(test.TestCase): class TypeInfoResolverTest(test.TestCase): def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( namer=None, - source_code=None, + source_code=source, source_file=None, namespace=namespace, arg_values=None, - arg_types=arg_types) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) + arg_types=arg_types, + recursive=True) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) node = type_info.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) return node def test_constructor_detection(self): @@ -83,16 +86,16 @@ class TypeInfoResolverTest(test.TestCase): self.assertEquals((training.__name__, 'GradientDescentOptimizer'), anno.getanno(call_node, 'type_fqn')) - def test_class_members(self): + def test_class_members_of_detected_constructor(self): def test_fn(): opt = training.GradientDescentOptimizer(0.1) opt.minimize(0) node = self._parse_and_analyze(test_fn, {'training': training}) - attr_call_node = node.body[0].body[1].value.func - self.assertEquals((training.__name__, 'GradientDescentOptimizer'), - anno.getanno(attr_call_node, 'type_fqn')) + method_call = node.body[0].body[1].value.func + self.assertEquals(training.GradientDescentOptimizer.minimize, + anno.getanno(method_call, 'live_val')) def test_class_members_in_with_stmt(self): @@ -106,11 +109,11 @@ class TypeInfoResolverTest(test.TestCase): self.assertEquals((session.__name__, 'Session'), anno.getanno(constructor_call, 'type_fqn')) - member_call = node.body[0].body[0].body[0].value.func - self.assertEquals((session.__name__, 'Session'), - anno.getanno(member_call, 'type_fqn')) + method_call = node.body[0].body[0].body[0].value.func + self.assertEquals(session.Session.run, anno.getanno(method_call, + 'live_val')) - def test_constructor_deta_dependent(self): + def test_constructor_data_dependent(self): def test_fn(x): if x > 0: @@ -119,16 +122,18 @@ class TypeInfoResolverTest(test.TestCase): opt = training.GradientDescentOptimizer(0.01) opt.minimize(0) - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'training': training}) + node = self._parse_and_analyze(test_fn, {'training': training}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_parameter_class_members(self): def test_fn(opt): opt.minimize(0) - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'training': training}) + node = self._parse_and_analyze(test_fn, {}) + method_call = node.body[0].body[0].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_parameter_class_members_with_value_hints(self): @@ -138,14 +143,13 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze( test_fn, {'training': training}, arg_types={ - 'opt': (('%s.GradientDescentOptimizer' % training.__name__), - training.GradientDescentOptimizer(0.1)) + 'opt': (training.GradientDescentOptimizer.__name__, + training.GradientDescentOptimizer) }) - attr_call_node = node.body[0].body[0].value.func - self.assertEquals( - tuple(training.__name__.split('.')) + ('GradientDescentOptimizer',), - anno.getanno(attr_call_node, 'type_fqn')) + method_call = node.body[0].body[0].value.func + self.assertEquals(training.GradientDescentOptimizer.minimize, + anno.getanno(method_call, 'live_val')) def test_function_variables(self): @@ -156,8 +160,9 @@ class TypeInfoResolverTest(test.TestCase): foo = bar foo() - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'bar': bar}) + node = self._parse_and_analyze(test_fn, {'bar': bar}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_nested_members(self): @@ -165,8 +170,9 @@ class TypeInfoResolverTest(test.TestCase): foo = training.GradientDescentOptimizer(0.1) foo.bar.baz() - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'training': training}) + node = self._parse_and_analyze(test_fn, {'training': training}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py index 77c5fbe02a11ed4a6b3d2cd80a032858f5b07e33..1039fc871393b1b4a388ce56d5beb563a37ee565 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/py2tf/pyct/templates.py @@ -22,11 +22,13 @@ from __future__ import division from __future__ import print_function import ast -import copy +import textwrap import gast +from tensorflow.contrib.py2tf.pyct import copier from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import qual_names class ReplaceTransformer(gast.NodeTransformer): @@ -40,6 +42,7 @@ class ReplaceTransformer(gast.NodeTransformer): that these placeholders will be replaced by. """ self.replacements = replacements + self.in_replacements = False # TODO(mdan): Make a more detailed pass and clean up if needed. @@ -61,34 +64,53 @@ class ReplaceTransformer(gast.NodeTransformer): node.name = repl.id return node - def visit_Name(self, node): - if node.id in self.replacements: - # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations. - new_nodes = copy.copy(self.replacements[node.id]) - if isinstance(new_nodes, gast.AST): - new_nodes = [new_nodes] - # Preserve the target context. - for n in new_nodes: - if isinstance(n, gast.Tuple): - for e in n.elts: - e.ctx = node.ctx - n.ctx = node.ctx - if len(new_nodes) == 1: - new_nodes, = new_nodes - return new_nodes + def _set_inner_child_context(self, node, ctx): + if isinstance(node, gast.Attribute): + self._set_inner_child_context(node.value, ctx) + node.ctx = gast.Load() + elif isinstance(node, gast.Name): + node.ctx = ctx else: + raise ValueError('unexpected node type "%s"' % node) + + def visit_Name(self, node): + if node.id not in self.replacements: return node + new_nodes = copier.copy_clean(self.replacements[node.id]) + if isinstance(new_nodes, gast.AST): + new_nodes = [new_nodes] + + # Preserve the target context. + for n in new_nodes: + if isinstance(n, gast.Tuple): + for e in n.elts: + self._set_inner_child_context(e, node.ctx) + if isinstance(n, gast.Attribute): + # For attributes, the inner Name node receives the context, while the + # outer ones have it set to Load. + self._set_inner_child_context(n, node.ctx) + else: + n.ctx = node.ctx + + if len(new_nodes) == 1: + new_nodes, = new_nodes + + return new_nodes + -def _strings_to_names(n): +def _convert_to_ast(n): + """Convert from a known data type to AST.""" if isinstance(n, str): # Note: the node will receive the ctx value from the template, see # ReplaceTransformer.visit_Name. return gast.Name(id=n, ctx=None, annotation=None) + if isinstance(n, qual_names.QN): + return n.ast() if isinstance(n, list): - return [_strings_to_names(e) for e in n] + return [_convert_to_ast(e) for e in n] if isinstance(n, tuple): - return tuple(_strings_to_names(e) for e in n) + return tuple(_convert_to_ast(e) for e in n) return n @@ -119,7 +141,10 @@ def replace(template, **replacements): """ if not isinstance(template, str): raise ValueError('Expected string template, got %s' % type(template)) - tree = parser.parse_str(template) + tree = parser.parse_str(textwrap.dedent(template)) for k in replacements: - replacements[k] = _strings_to_names(replacements[k]) - return ReplaceTransformer(replacements).visit(tree).body + replacements[k] = _convert_to_ast(replacements[k]) + results = ReplaceTransformer(replacements).visit(tree).body + if isinstance(results, list): + return [qual_names.resolve(r) for r in results] + return qual_names.resolve(results) diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py index 1143131283cd92c42abfc73d5728fac96cc31c23..0e3d07e378972de67d67d9ef7fd9bef351c5b5be 100644 --- a/tensorflow/contrib/py2tf/pyct/templates_test.py +++ b/tensorflow/contrib/py2tf/pyct/templates_test.py @@ -27,6 +27,16 @@ from tensorflow.python.platform import test class TemplatesTest(test.TestCase): + def test_replace_tuple(self): + template = """ + def test_fn(a, c): + return b, + """ + + node = templates.replace(template, b=('a', 'c'))[0] + result = compiler.ast_to_object(node) + self.assertEquals((2, 3), result.test_fn(2, 3)) + def test_replace_variable(self): template = """ def test_fn(a): diff --git a/tensorflow/contrib/py2tf/pyct/transformer.py b/tensorflow/contrib/py2tf/pyct/transformer.py index d5aa23eaebbbf7540d52d9fa9cc5292e0f756e6d..877d52af016af720424c8a56257fec9ab64611cb 100644 --- a/tensorflow/contrib/py2tf/pyct/transformer.py +++ b/tensorflow/contrib/py2tf/pyct/transformer.py @@ -18,8 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + import gast +import six +from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import pretty_printer @@ -41,18 +45,25 @@ class Base(gast.NodeTransformer): self.context = context def visit(self, node): + source_code = self.context.source_code + source_file = self.context.source_file try: - source_code = self.context.source_code - source_file = self.context.source_file if source_code and hasattr(node, 'lineno'): self._lineno = node.lineno self._col_offset = node.col_offset + if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + return node return super(Base, self).visit(node) - except ValueError as e: - msg = '%s\nOccurred at node:\n%s' % (str(e), pretty_printer.fmt(node)) + except (ValueError, AttributeError, KeyError, NotImplementedError, + AssertionError) as e: + msg = '%s: %s\nOccurred at node:\n%s' % ( + e.__class__.__name__, str(e), pretty_printer.fmt(node, color=False)) if source_code: - line = self._source.splitlines()[self._lineno - 1] + line = source_code.splitlines()[self._lineno - 1] else: line = '' - raise PyFlowParseError( - msg, (source_file, self._lineno, self._col_offset + 1, line)) + six.reraise(PyFlowParseError, + PyFlowParseError( + msg, + (source_file, self._lineno, self._col_offset + 1, line)), + sys.exc_info()[2]) diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..01804aa8834f23851dbc7af3ae9082645639ffbc --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/BUILD @@ -0,0 +1,37 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "utils", + srcs = [ + "__init__.py", + "context_managers.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ], +) + +py_test( + name = "context_managers_test", + srcs = ["context_managers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bca33e89e99cd5939683ad10a2eb17db243af2ef --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility module that contains APIs usable in the generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns diff --git a/tensorflow/contrib/py2tf/utils/context_managers.py b/tensorflow/contrib/py2tf/utils/context_managers.py new file mode 100644 index 0000000000000000000000000000000000000000..47d98399971da039a8987ea17039c8b44bfa3b61 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/context_managers.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various context managers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + + +def control_dependency_on_returns(tf, return_value): + """Create a TF control dependency on the return values of a function. + + If the function had no return value, a no-op context is returned. + + Args: + tf: The TensorFlow module. + return_value: The return value to set as control dependency. + + Returns: + A context manager. + """ + if return_value is None: + return contextlib.contextmanager(lambda: (yield))() + # TODO(mdan): Filter to tensor objects. + if not isinstance(return_value, (list, tuple)): + return_value = (return_value,) + return tf.control_dependencies(return_value) diff --git a/tensorflow/contrib/py2tf/utils/context_managers_test.py b/tensorflow/contrib/py2tf/utils/context_managers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c903f082528118aba2d0163b25a38178b99a17e7 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/context_managers_test.py @@ -0,0 +1,43 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for context_managers module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import context_managers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class ContextManagersTest(test.TestCase): + + def test_control_dependency_on_returns(self): + # Just dry run them. + with context_managers.control_dependency_on_returns(ops, None): + pass + with context_managers.control_dependency_on_returns( + ops, constant_op.constant(1)): + pass + with context_managers.control_dependency_on_returns( + ops, [constant_op.constant(1), + constant_op.constant(2)]): + pass + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 3c5b34a0a6adb2f4e340a8e378c1eb51a2e2b534..b7d525a1fa203fd150642c18304759e1a9c48c4b 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -77,9 +77,13 @@ py_library( "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", + "//tensorflow/python:ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index aa605e6caadf4d1e69a4a331b1e580797e4fdef8..7fa0d484ec94f90a817630f2eb985c0cdc0e4c04 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -26,14 +26,16 @@ from tensorflow.contrib.quantize.python import input_to_ops from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops +from tensorflow.python.training import training_util from tensorflow.python.util import compat -def FoldBatchNorms(graph): +def FoldBatchNorms(graph, freeze_batch_norm_delay=None, is_training=True): """Finds batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -41,15 +43,25 @@ def FoldBatchNorms(graph): Args: graph: Graph to walk and modify. + freeze_batch_norm_delay: How many steps to wait before freezing + moving mean and variance and using them for batch normalization. This value + is used only when is_training is True. + is_training: Bool, true if training Raises: ValueError: When batch norm folding fails. """ - _FoldFusedBatchNorms(graph) - _FoldUnfusedBatchNorms(graph) + _FoldFusedBatchNorms( + graph, + freeze_batch_norm_delay=freeze_batch_norm_delay, + is_training=is_training) + _FoldUnfusedBatchNorms( + graph, + freeze_batch_norm_delay=freeze_batch_norm_delay, + is_training=is_training) -def _FoldFusedBatchNorms(graph): +def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training): """Finds fused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -57,6 +69,9 @@ def _FoldFusedBatchNorms(graph): Args: graph: Graph to walk and modify. + freeze_batch_norm_delay: How many steps to wait before freezing + moving mean and variance and using them for batch normalization + is_training: Bool, true if training Raises: ValueError: When batch norm folding fails. @@ -67,8 +82,7 @@ def _FoldFusedBatchNorms(graph): # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope # named `scope`. Otherwise, TF creates a unique scope whose name starts with # `scope`. - with graph.as_default(), graph.name_scope(scope + sep), ops.device( - match.bn_op.device): + with graph.as_default(), graph.name_scope(scope + sep): with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep): # new weights = old weights * gamma / sqrt(variance + epsilon) # new biases = -mean * gamma / sqrt(variance + epsilon) + beta @@ -79,9 +93,18 @@ def _FoldFusedBatchNorms(graph): match.mean_tensor * multiplier_tensor, name='bias') + correction_scale, correction_recip, correction_offset = None, None, None + if is_training: + correction_scale, correction_recip, correction_offset = ( + _ComputeBatchNormCorrections( + context='', + match=match, + freeze_batch_norm_delay=freeze_batch_norm_delay, + fused_batch_norm=True)) # The shape of depthwise weights is different, so we need to reshape the # multiplier_tensor to ensure that the scaled_weight_tensor has the # expected shape. + weights = match.weight_tensor if match.layer_op.type == 'DepthwiseConv2dNative': new_shape = [ match.weight_tensor.get_shape().as_list()[2], @@ -90,15 +113,25 @@ def _FoldFusedBatchNorms(graph): multiplier_tensor = array_ops.reshape( multiplier_tensor, new_shape, name='scale_reshape') - # TODO(suharshs): This naming of the following ops needs to carefully - # follow the naming expected by quantize.py. Generalize the quantize code - # to not require these delicate naming conventions. - scaled_weight_tensor = math_ops.multiply( - match.weight_tensor, multiplier_tensor, name='mul_fold') + if correction_scale is not None: + correction_scale = array_ops.reshape( + correction_scale, new_shape, name='correction_reshape') + + if correction_scale is not None: + weights = math_ops.multiply( + correction_scale, weights, name='correction_mult') + scaled_weight_tensor = math_ops.multiply( + weights, multiplier_tensor, name='mul_fold') new_layer_tensor = _CloneWithNewOperands( match.layer_op, match.input_tensor, scaled_weight_tensor) + if correction_recip is not None: + new_layer_tensor = math_ops.multiply( + correction_recip, new_layer_tensor, name='post_conv_mul') + new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset), + 'correction_add') + bias_add_tensor = math_ops.add( new_layer_tensor, bias_tensor, name='add_fold') @@ -109,46 +142,6 @@ def _FoldFusedBatchNorms(graph): 'Unexpected inputs to op: %s' % match.output_tensor.name) -def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): - """Clones layer_op with input_tensor and weight_tensor as new inputs.""" - new_layer_name = layer_op.name.split('/')[-1] + '_Fold' - if layer_op.type == 'Conv2D': - return nn_ops.conv2d( - input_tensor, - weight_tensor, - strides=layer_op.get_attr('strides'), - padding=layer_op.get_attr('padding'), - use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), - data_format=layer_op.get_attr('data_format'), - name=new_layer_name) - elif layer_op.type == 'MatMul': - return math_ops.matmul( - input_tensor, - weight_tensor, - transpose_a=layer_op.get_attr('transpose_a'), - transpose_b=layer_op.get_attr('transpose_b'), - name=new_layer_name) - elif layer_op.type == 'DepthwiseConv2dNative': - return nn.depthwise_conv2d( - input_tensor, - weight_tensor, - strides=layer_op.get_attr('strides'), - padding=layer_op.get_attr('padding'), - name=new_layer_name) - else: - raise ValueError('Cannot handle operation of type: %s' % layer_op.type) - - -@ops.RegisterGradient('FoldFusedBatchNormGrad') -def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, - unused_2): - x = op.inputs[0] - n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements() - dmean_dx = grad_mean / n - dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1) - return (dmean_dx + dvar_dx), None, None, None, None - - def _FindFusedBatchNorms(graph): """Finds all ops and tensors related to found FusedBatchNorms. @@ -165,37 +158,59 @@ def _FindFusedBatchNorms(graph): mean_pattern = graph_matcher.OpTypePattern('*') variance_pattern = graph_matcher.OpTypePattern('*') - conv_pattern = graph_matcher.OpTypePattern( - 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern]) + moving_average_pattern = graph_matcher.OpTypePattern('*') + bn_decay_pattern = graph_matcher.OpTypePattern('*') + layer_pattern = graph_matcher.OpTypePattern( + 'Conv2D|DepthwiseConv2dNative|MatMul', + inputs=[input_pattern, weight_pattern]) # MatMul has a Reshape between it and FusedBatchNorm. - matmul_pattern = graph_matcher.OpTypePattern( - 'MatMul', inputs=[input_pattern, weight_pattern]) matmul_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', inputs=[matmul_pattern, + 'Reshape', inputs=[layer_pattern, graph_matcher.OpTypePattern('*')]) - conv_batch_norm_pattern = graph_matcher.OpTypePattern( - 'FusedBatchNorm', - inputs=[ - conv_pattern, gamma_pattern, beta_pattern, mean_pattern, - variance_pattern - ]) - matmul_batch_norm_pattern = graph_matcher.OpTypePattern( + batch_norm_pattern = graph_matcher.OpTypePattern( 'FusedBatchNorm', inputs=[ - matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern, - variance_pattern + graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]), + gamma_pattern, beta_pattern, mean_pattern, variance_pattern ]) matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', - inputs=[matmul_batch_norm_pattern, - graph_matcher.OpTypePattern('*')]) + 'Reshape', inputs=[batch_norm_pattern, + graph_matcher.OpTypePattern('*')]) + + bn_matcher = graph_matcher.GraphMatcher( + graph_matcher.OneofPattern( + [matmul_bn_output_reshape_pattern, batch_norm_pattern])) + + moving_average_sub_pattern = graph_matcher.OpTypePattern( + 'Sub', inputs=[moving_average_pattern, batch_norm_pattern]) + moving_average_mul_pattern = graph_matcher.OpTypePattern( + 'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern]) - conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern) - matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern) + moving_avg_mul_matcher = graph_matcher.GraphMatcher( + moving_average_mul_pattern) + + for match_result in bn_matcher.match_graph(graph): + moving_mean_tensor = None + moving_variance_tensor = None + bn_decay_mean_tensor = None + bn_decay_var_tensor = None + layer_op = match_result.get_op(layer_pattern) + layer_tensor = match_result.get_tensor(layer_pattern) + bn_op = match_result.get_op(batch_norm_pattern) + batch_epsilon_tensor = bn_op.get_attr('epsilon') + + # In the MatMul case, the output of batch norm is reshaped back into a + # 2D tensor, so the output_tensor is the output of the Reshape op. + output_tensor = bn_op.outputs[0] + if layer_op.type == 'MatMul': + output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) + # If the matcher didn't match matmul_bn_output_reshape, there will be + # another match for this 'MatMul' later, so we can skip this one. + if output_reshape_op is None: + continue + output_tensor = output_reshape_op.outputs[0] - def _GetCommonTensors(match_result, bn_op, bn_input_tensor): - """Gets tensors needed for FusedBatchNormMatch from match_result.""" input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) @@ -222,48 +237,30 @@ def _FindFusedBatchNorms(graph): # calculation, the variance is corrected by the term N/N-1 (Bessel's # correction). The variance tensor read from FuseBatchNorm has bessel's # correction applied, so we undo it here. - n = math_ops.cast( - array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor), - dtypes.float32) - variance_tensor = bn_op.outputs[2] * (n - 1) / n + scope, sep, _ = bn_op.name.rpartition('/') + g = ops.get_default_graph() + with g.as_default(), g.name_scope(scope + sep): + n = math_ops.cast( + array_ops.size(layer_tensor) / array_ops.size(mean_tensor), + dtypes.float32) + variance_tensor = math_ops.multiply( + bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction') + # TODO(suharshs): Find a way to get rid of this inner match. + for mul_match_result in moving_avg_mul_matcher.match_graph(graph): + sub_op = mul_match_result.get_op(moving_average_sub_pattern) + if sub_op.inputs[1].name == bn_op.outputs[1].name: + # During training: Batch Mean is bn_op.outputs[1] + moving_mean_tensor = sub_op.inputs[0] + bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern) + if sub_op.inputs[1].name == bn_op.outputs[2].name: + # During training: Batch Var is bn_op.outputs[2] + moving_variance_tensor = sub_op.inputs[0] + bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern) else: mean_tensor = match_result.get_tensor(mean_pattern) variance_tensor = match_result.get_tensor(variance_pattern) - return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) - - for match_result in conv_matcher.match_graph(graph): - layer_op = match_result.get_op(conv_pattern) - layer_tensor = match_result.get_tensor(conv_pattern) - bn_op = match_result.get_op(conv_batch_norm_pattern) - # In the case of convolution the output_tensor is the output of bn_op. - output_tensor = bn_op.outputs[0] - - (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor) - yield _FusedBatchNormMatch( - layer_op=layer_op, - bn_op=bn_op, - output_tensor=output_tensor, - input_tensor=input_tensor, - weight_tensor=weight_tensor, - gamma_tensor=gamma_tensor, - beta_tensor=beta_tensor, - mean_tensor=mean_tensor, - variance_tensor=variance_tensor) - - for match_result in matmul_matcher.match_graph(graph): - layer_op = match_result.get_op(matmul_pattern) - layer_tensor = match_result.get_tensor(matmul_pattern) - bn_op = match_result.get_op(matmul_batch_norm_pattern) - # In the MatMul case, the output of batch norm is reshaped back into a - # 2D tensor, so the output_tensor is the output of the Reshape op. - output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) - output_tensor = output_reshape_op.outputs[0] - (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor) - yield _FusedBatchNormMatch( + yield _BatchNormMatch( layer_op=layer_op, bn_op=bn_op, output_tensor=output_tensor, @@ -272,63 +269,156 @@ def _FindFusedBatchNorms(graph): gamma_tensor=gamma_tensor, beta_tensor=beta_tensor, mean_tensor=mean_tensor, - variance_tensor=variance_tensor) + variance_tensor=variance_tensor, + moving_mean_tensor=moving_mean_tensor, + moving_variance_tensor=moving_variance_tensor, + bn_decay_mean_tensor=bn_decay_mean_tensor, + bn_decay_var_tensor=bn_decay_var_tensor, + batch_epsilon_tensor=batch_epsilon_tensor) + + +def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, + fused_batch_norm): + """Computes batch norm correction params. + + Before batch normalization is frozen: + We use batch statistics for batch norm. + correction_scale = sigma_b/sigma_mv + correction_recip = 1/correction_scale + correction_offset = 0 + + After batch normalization is frozen: + correction_scale = sigma_b/sigma_mv + correction_recip = 1 + correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv). + + Batch norm is frozen if global_step > bn_freeze_delay. + The corrections ensure that: + a) The weights are quantized after scaling by gamma/sigma_mv. This enables + smoother training as the scaling on the weights changes slowly, rather than + jump across mini-batches + b) Changing the values of the corrections allows for one to switch between + using batch statistics to using moving mean and average, without requiring + changes to batch_norm -class _FusedBatchNormMatch(object): - """Contains all information related to a found FusedBatchNorm.""" - - def __init__(self, layer_op, bn_op, output_tensor, input_tensor, - weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor): - self._layer_op = layer_op - self._bn_op = bn_op - self._output_tensor = output_tensor - self._input_tensor = input_tensor - self._weight_tensor = weight_tensor - self._gamma_tensor = gamma_tensor - self._beta_tensor = beta_tensor - self._mean_tensor = mean_tensor - self._variance_tensor = variance_tensor - - @property - def layer_op(self): - return self._layer_op - - @property - def bn_op(self): - return self._bn_op - - @property - def output_tensor(self): - return self._output_tensor + Args: + context: The scope under which we look for batch norm params + match: Object containg required batch norm tensors for correction + computation + freeze_batch_norm_delay: Delay in steps at which computation switches + from regular batch norm to frozen mean and variance. + fused_batch_norm: Bool, true if fused batch norm is used - @property - def input_tensor(self): - return self._input_tensor + Returns: + A tuple of correction_scale, correction_recip, correction_offset + """ - @property - def weight_tensor(self): - return self._weight_tensor + g = ops.get_default_graph() + with g.name_scope(context + '/batch_norm_correction'): + recip_sigma_mv = math_ops.rsqrt( + match.moving_variance_tensor + match.batch_epsilon_tensor) + recip_sigma = math_ops.rsqrt( + match.variance_tensor + match.batch_epsilon_tensor) + correction_scale = math_ops.divide( + recip_sigma_mv, recip_sigma, name='scale_compute') + correction_scale = array_ops.identity( + correction_scale, name='correction_scale') + correction_recip = math_ops.reciprocal( + correction_scale, name='reciprocal_compute') + correction_offset = math_ops.multiply( + match.gamma_tensor, + match.mean_tensor * recip_sigma - + match.moving_mean_tensor * recip_sigma_mv, + name='offset_compute') + + if freeze_batch_norm_delay is not None: + use_mv_avg = math_ops.greater_equal( + training_util.get_or_create_global_step(), + freeze_batch_norm_delay, + name='use_moving_average') + else: + use_mv_avg = False + + bn_decay_zero = 0.0 + bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers()) + bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers()) + + bn_decay_mean_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_mean_tensor, + name='freeze_moving_mean') + graph_editor.reroute_ts( + [bn_decay_mean_out], [match.bn_decay_mean_tensor], + can_modify=bn_decay_mean_consumers) + + if fused_batch_norm is False: + bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) + bn_decay_var_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_var_tensor, + name='freeze_moving_var') + graph_editor.reroute_ts( + [bn_decay_var_out], [match.bn_decay_var_tensor], + can_modify=bn_decay_var_consumers) + + correction_recip = utils.smart_cond( + use_mv_avg, + lambda: array_ops.ones(correction_scale.shape), + lambda: correction_recip, + name='correction_recip') + + correction_offset = utils.smart_cond( + use_mv_avg, + lambda: correction_offset, + lambda: array_ops.zeros(correction_offset.shape), + name='correction_offset') + return correction_scale, correction_recip, correction_offset - @property - def gamma_tensor(self): - return self._gamma_tensor - @property - def beta_tensor(self): - return self._beta_tensor +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): + """Clones layer_op with input_tensor and weight_tensor as new inputs.""" + new_layer_name = layer_op.name.split('/')[-1] + '_Fold' + if layer_op.type == 'Conv2D': + return nn_ops.conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), + data_format=layer_op.get_attr('data_format'), + name=new_layer_name) + elif layer_op.type == 'MatMul': + return math_ops.matmul( + input_tensor, + weight_tensor, + transpose_a=layer_op.get_attr('transpose_a'), + transpose_b=layer_op.get_attr('transpose_b'), + name=new_layer_name) + elif layer_op.type == 'DepthwiseConv2dNative': + return nn.depthwise_conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + name=new_layer_name) + else: + raise ValueError('Cannot handle operation of type: %s' % layer_op.type) - @property - def mean_tensor(self): - return self._mean_tensor - @property - def variance_tensor(self): - return self._variance_tensor +@ops.RegisterGradient('FoldFusedBatchNormGrad') +def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, + unused_2): + x = op.inputs[0] + n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements() + dmean_dx = grad_mean / n + dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1) + return (dmean_dx + dvar_dx), None, None, None, None -def _FoldUnfusedBatchNorms(graph): +def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training): """Finds unfused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -336,6 +426,9 @@ def _FoldUnfusedBatchNorms(graph): Args: graph: Graph to walk and modify. + freeze_batch_norm_delay: How many steps to wait before freezing + moving mean and variance and using them for batch normalization + is_training: Bool, True if training Raises: ValueError: When batch norm folding fails. @@ -346,7 +439,12 @@ def _FoldUnfusedBatchNorms(graph): has_scaling = _HasScaling(graph, input_to_ops_map, bn) # The mangling code intimately depends on BatchNorm node's internals. - original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling) + original_op, folded_op = _CreateFoldedOp( + graph, + bn, + has_scaling=has_scaling, + freeze_batch_norm_delay=freeze_batch_norm_delay, + is_training=is_training) activation = common.GetEndpointActivationOp(graph, bn) if activation: @@ -368,46 +466,85 @@ def _FoldUnfusedBatchNorms(graph): raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) -def _HasScaling(graph, input_to_ops_map, bn): - r"""Checks if batch norm has scaling enabled. - - Difference between batch norm with scaling and without is that with scaling: - - Rsqrt -> mul -> mul_1 - \-> mul_2 - - where - mul multiplies gamma by inverse square root of EMA of batch variance, - mul_1 multiplies output of mul with output from the base operation - (convolution, FC or depthwise convolution), - mul_2 multiplies output of mul with EMA of batch mean, - and without scaling: - - Rsqrt -> mul - \-> mul_1 - - where - mul multiplies the inverse square root of EMA of batch variance with output - from the base operation, - mul_1 multiplies inverse square root of EMA of batch variance with EMA - of batch mean. +def _GetBatchNormParams(graph, context, has_scaling): + """Extracts relevant tensors for folding batch norms. Args: graph: Graph to inspect. - input_to_ops_map: InputToOps object containing mapping from tensor's name - to ops that take it as input. - bn: Batch norm layer prefix string. + context: The scope under which we look for batch norm params + has_scaling: Bool that specifies if scaling is done as part of batch + norm Returns: - A boolean indicating whether this batch norm layer has scaling enabled. + _BatchNormMatch containing all required batch norm parameters """ - rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') - rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) - - return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 - - -def _CreateFoldedOp(graph, context, has_scaling): + gamma_tensor = None + batch_mean_tensor = None + batch_variance_tensor = None + moving_mean_tensor = None + moving_variance_tensor = None + batch_epsilon_tensor = None + bn_decay_mean_tensor = None + bn_decay_var_tensor = None + + split_context = context.split('/') + base_context = split_context[-1] + + oplist = graph.get_operations() + op_suffix_gamma = base_context + '/BatchNorm/gamma' + op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze' + op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1' + op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read' + op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read' + op_suffix_epsilon = base_context + '/BatchNorm/batchnorm/add/y' + op_suffix_bn_decay_mean = base_context + '/BatchNorm/AssignMovingAvg/decay' + op_suffix_bn_decay_var = base_context + '/BatchNorm/AssignMovingAvg_1/decay' + + # Parse through list of ops to find relevant ops + for op in oplist: + if op.name.endswith(op_suffix_mean): + # This is an efficient way to check for two things: + # Is batch norm present and is it training mode? + # Batch statistics are computed only during batch norm in training + batch_mean_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_variance): + batch_variance_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_moving_mean): + moving_mean_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_moving_variance): + moving_variance_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_epsilon): + batch_epsilon_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_bn_decay_mean): + bn_decay_mean_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_bn_decay_var): + bn_decay_var_tensor = graph.get_tensor_by_name(op.name + ':0') + if has_scaling: + if op.name.endswith(op_suffix_gamma): + gamma_tensor = graph.get_tensor_by_name(op.name + ':0') + + if not has_scaling: + gamma_tensor = array_ops.ones(batch_mean_tensor.shape) + + return _BatchNormMatch( + layer_op=None, + bn_op=None, + output_tensor=None, + input_tensor=None, + weight_tensor=None, + gamma_tensor=gamma_tensor, + beta_tensor=None, + mean_tensor=batch_mean_tensor, + variance_tensor=batch_variance_tensor, + moving_mean_tensor=moving_mean_tensor, + moving_variance_tensor=moving_variance_tensor, + bn_decay_mean_tensor=bn_decay_mean_tensor, + bn_decay_var_tensor=bn_decay_var_tensor, + batch_epsilon_tensor=batch_epsilon_tensor) + + +def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, + is_training): """Folds in batch norm layer into preceding convolution or FC layer. Creates 3 new nodes, connects their inputs and adds them to the graph: @@ -419,6 +556,9 @@ def _CreateFoldedOp(graph, context, has_scaling): context: String, batch norm context, i.e. node into which BatchNorm is nested. has_scaling: Whether the batch norm has scaling enabled. + freeze_batch_norm_delay: How many steps to wait before freezing + moving mean and variance and using them for batch normalization + is_training: Bool, true if training Raises: ValueError: When operation type is not supported, or input and output tensor @@ -435,19 +575,43 @@ def _CreateFoldedOp(graph, context, has_scaling): mul_scale_name) op_below = mul_scale.inputs[0].op weights = op_below.inputs[1] - + match = _GetBatchNormParams( + graph=graph, context=context, has_scaling=has_scaling) + correction_scale, correction_recip, correction_offset = None, None, None + if is_training: + correction_scale, correction_recip, correction_offset = ( + _ComputeBatchNormCorrections( + context=context, + match=match, + freeze_batch_norm_delay=freeze_batch_norm_delay, + fused_batch_norm=False)) # Special handling for weights of depthwise convolution. if op_below.type == 'DepthwiseConv2dNative': - new_shape = [weights.get_shape().as_list()[2], - weights.get_shape().as_list()[3]] + new_shape = [ + weights.get_shape().as_list()[2], + weights.get_shape().as_list()[3] + ] scale_name = 'mul' if has_scaling else 'Rsqrt' - scale = graph.get_operation_by_name(context + '/BatchNorm/batchnorm/' + - scale_name) + scale = graph.get_operation_by_name( + context + '/BatchNorm/batchnorm/' + scale_name) scale = array_ops.reshape(scale.outputs[0], new_shape, context + '/scale_reshape') - mul_fold = _CloneOp(mul_scale, context + '/mul_fold', - [(0, weights), (1, scale)]) + + if correction_scale is not None: + correction_scale = array_ops.reshape(correction_scale, new_shape, + context + '/correction_reshape') + with ops.device(mul_scale.device): + weights = math_ops.multiply(correction_scale, weights, + context + '/correction_mult') + + mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights), + (1, scale)]) elif op_below.type in ['Conv2D', 'MatMul']: + + if correction_scale is not None: + with ops.device(mul_scale.device): + weights = math_ops.multiply(correction_scale, weights, + context + '/correction_mult') mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)]) else: raise ValueError('Cannot handle operation of type: %s' % op_below.op) @@ -456,10 +620,17 @@ def _CreateFoldedOp(graph, context, has_scaling): conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold', [(1, mul_fold.outputs[0])]) - add_shift = graph.get_operation_by_name(context + - '/BatchNorm/batchnorm/add_1') - add_fold = _CloneOp(add_shift, context + '/add_fold', - [(0, conv_or_fc_folded.outputs[0])]) + add_shift = graph.get_operation_by_name( + context + '/BatchNorm/batchnorm/add_1') + + corrected_output = conv_or_fc_folded.outputs[0] + if correction_offset is not None: + with ops.device(conv_or_fc_folded.device): + corrected_output = math_ops.multiply(correction_recip, corrected_output, + context + '/post_conv_mul') + corrected_output = math_ops.add(corrected_output, (correction_offset), + context + '/correction_add') + add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)]) _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0]) return add_shift, add_fold @@ -603,3 +774,121 @@ def _AssertShapesMatch(op_name, in_tensor, out_tensor): if not in_shape.is_compatible_with(out_shape): raise ValueError('%s should not change tensor shape: input %s, ' 'output %s' % (op_name, in_shape, out_shape)) + + +def _HasScaling(graph, input_to_ops_map, bn): + r"""Checks if batch norm has scaling enabled. + + Difference between batch norm with scaling and without is that with scaling: + + Rsqrt -> mul -> mul_1 + \-> mul_2 + + where + mul multiplies gamma by inverse square root of EMA of batch variance, + mul_1 multiplies output of mul with output from the base operation + (convolution, FC or depthwise convolution), + mul_2 multiplies output of mul with EMA of batch mean, + and without scaling: + + Rsqrt -> mul + \-> mul_1 + + where + mul multiplies the inverse square root of EMA of batch variance with output + from the base operation, + mul_1 multiplies inverse square root of EMA of batch variance with EMA + of batch mean. + + Args: + graph: Graph to inspect. + input_to_ops_map: InputToOps object containing mapping from tensor's name + to ops that take it as input. + bn: Batch norm layer prefix string. + + Returns: + A boolean indicating whether this batch norm layer has scaling enabled. + """ + rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') + rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) + + return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 + + +class _BatchNormMatch(object): + """Contains all information related to a found Fused/UnfusedBatchNorm.""" + + def __init__(self, layer_op, bn_op, output_tensor, input_tensor, + weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor, moving_mean_tensor, moving_variance_tensor, + bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor): + self._layer_op = layer_op + self._bn_op = bn_op + self._output_tensor = output_tensor + self._input_tensor = input_tensor + self._weight_tensor = weight_tensor + self._gamma_tensor = gamma_tensor + self._beta_tensor = beta_tensor + self._mean_tensor = mean_tensor + self._variance_tensor = variance_tensor + self._moving_mean_tensor = moving_mean_tensor + self._moving_variance_tensor = moving_variance_tensor + self._bn_decay_mean_tensor = bn_decay_mean_tensor + self._bn_decay_var_tensor = bn_decay_var_tensor + self._batch_epsilon_tensor = batch_epsilon_tensor + + @property + def layer_op(self): + return self._layer_op + + @property + def bn_op(self): + return self._bn_op + + @property + def output_tensor(self): + return self._output_tensor + + @property + def input_tensor(self): + return self._input_tensor + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def gamma_tensor(self): + return self._gamma_tensor + + @property + def beta_tensor(self): + return self._beta_tensor + + @property + def mean_tensor(self): + return self._mean_tensor + + @property + def variance_tensor(self): + return self._variance_tensor + + @property + def moving_mean_tensor(self): + return self._moving_mean_tensor + + @property + def moving_variance_tensor(self): + return self._moving_variance_tensor + + @property + def batch_epsilon_tensor(self): + return self._batch_epsilon_tensor + + @property + def bn_decay_mean_tensor(self): + return self._bn_decay_mean_tensor + + @property + def bn_decay_var_tensor(self): + return self._bn_decay_var_tensor diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index ecf321ff573181c7a2e325770a8dde223bf0c021..330bd8a6474c18b236b635d930e7a1df9594d84f 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -46,26 +46,27 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def _RunTestOverParameters(self, test_fn): parameters_list = [ - # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm) - (nn_ops.relu6, 'Relu6', False, False, False), - (nn_ops.relu, 'Relu', False, False, False), - (nn_ops.relu6, 'Relu6', True, False, False), - (nn_ops.relu, 'Relu', True, False, False), - (nn_ops.relu6, 'Relu6', False, True, False), - (nn_ops.relu, 'Relu', False, True, False), - (nn_ops.relu6, 'Relu6', True, True, False), - (nn_ops.relu, 'Relu', True, True, False), + # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, + # freeze_batch_norm_delay) + (nn_ops.relu6, 'Relu6', False, False, False, 100), + (nn_ops.relu, 'Relu', False, False, False, None), + (nn_ops.relu6, 'Relu6', True, False, False, 100), + (nn_ops.relu, 'Relu', True, False, False, None), + (nn_ops.relu6, 'Relu6', False, True, False, 100), + (nn_ops.relu, 'Relu', False, True, False, None), + (nn_ops.relu6, 'Relu6', True, True, False, 100), + (nn_ops.relu, 'Relu', True, True, False, None), # Fused batch norm always has scaling enabled. - (nn_ops.relu6, 'Relu6', False, True, True), - (nn_ops.relu, 'Relu', False, True, True), - (nn_ops.relu6, 'Relu6', True, True, True), - (nn_ops.relu, 'Relu', True, True, True), + (nn_ops.relu6, 'Relu6', False, True, True, None), + (nn_ops.relu, 'Relu', False, True, True, 100), + (nn_ops.relu6, 'Relu6', True, True, True, None), + (nn_ops.relu, 'Relu', True, True, True, 100), ] for params in parameters_list: - test_fn(params[0], params[1], params[2], params[3], params[4]) + test_fn(params[0], params[1], params[2], params[3], params[4], params[5]) def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling, - fused_batch_norm): + fused_batch_norm, freeze_batch_norm_delay): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Args: @@ -75,6 +76,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -99,12 +102,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', + scope + '/correction_mult', self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) @@ -113,12 +117,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertEqual(folded_conv.type, 'Conv2D') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/Conv2D_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -128,7 +132,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldConv2d) def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Tests that folding works even with an input shape where some dimensions are @@ -141,6 +146,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -164,12 +171,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', + scope + '/correction_mult', self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) @@ -177,12 +185,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold') self.assertEqual(folded_conv.type, 'Conv2D') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/Conv2D_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -192,7 +200,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests folding cases: inputs -> FC with batch norm -> Relu*. Args: @@ -202,6 +211,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -223,12 +234,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', + scope + '/correction_mult', self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) @@ -237,12 +249,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertEqual(folded_conv.type, 'MatMul') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/MatMul_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -252,7 +264,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. Args: @@ -262,6 +275,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -286,7 +301,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') @@ -295,8 +311,7 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): else: scale_reshape_op_name = scope + '/scale_reshape' self._AssertInputOpsAre(folded_mul, - [scope + '/depthwise_weights/read', - scale_reshape_op_name]) + [scope + '/correction_mult', scale_reshape_op_name]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) scale_reshape = g.get_operation_by_name(scale_reshape_op_name) @@ -311,12 +326,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/depthwise_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -326,7 +341,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests that running folded and unfolded BN returns the same results. Args: @@ -336,6 +352,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ random_seed.set_random_seed(1234) unfolded_g = ops.Graph() @@ -361,11 +379,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu_node = relu(node, name='test/' + relu_op_name) - folded_g = copy_graph.CopyGraph(unfolded_g) with folded_g.as_default(): - fold_batch_norms.FoldBatchNorms(folded_g) - + fold_batch_norms.FoldBatchNorms( + folded_g, + is_training=True, + freeze_batch_norm_delay=freeze_batch_norm_delay) with session.Session(graph=unfolded_g) as sess: sess.run(variables.global_variables_initializer()) grad_node = gradients.gradients(relu_node, inputs) diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py index e3581cc55905a0af7d0464bc0ec673d3ed7f0363..b458f039df0523b5b8b07cff7d14643154124b95 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher.py +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -18,8 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc -class OpTypePattern(object): + +class Pattern(object): + """The parent class of all patterns (e.g. OpTypePattern and OneofPattern).""" + + @abc.abstractmethod + def match(self, op, tensor): + """Returns the result of matching op/tensor against this pattern.""" + raise NotImplementedError('Method "match" not implemented.') + + +class OpTypePattern(Pattern): """A tree pattern that matches TF expressions with certain op types.""" def __init__(self, op_type, name=None, inputs=None): @@ -34,7 +45,7 @@ class OpTypePattern(object): similar TF op types. name: Optional string. The name of the pattern that can be looked up in MatchResult. - inputs: Optional list of `OpTypePattern`s or strings that specify the + inputs: Optional list of `Pattern`s or strings that specify the patterns for the inputs of a matching op. If None, this pattern accepts any inputs of a matching op. """ @@ -43,22 +54,51 @@ class OpTypePattern(object): if inputs is None: inputs = [] self._inputs = [ - input_pattern if isinstance(input_pattern, OpTypePattern) else - OpTypePattern(input_pattern) for input_pattern in inputs + input_pattern + if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern) + for input_pattern in inputs ] - @property - def op_type(self): - return self._op_type - - @property - def inputs(self): - return self._inputs - @property def name(self): return self._name + def match(self, op, tensor): + if self._op_type != '*': + if op.type not in self._op_type.split('|'): + return None + + match_result = MatchResult() + match_result.add(self, op, tensor) + + if not self._inputs: + # If pattern.inputs is empty, skips the rest and accepts all the inputs. + return match_result + + if len(op.inputs) != len(self._inputs): + return None + + for input_tensor, input_pattern in zip(op.inputs, self._inputs): + input_match_result = input_pattern.match(input_tensor.op, input_tensor) + if input_match_result is None: + return None + match_result.merge_from(input_match_result) + return match_result + + +class OneofPattern(Pattern): + """Matches one of the given sub-patterns.""" + + def __init__(self, sub_patterns): + self._sub_patterns = sub_patterns + + def match(self, op, tensor): + for sub_pattern in self._sub_patterns: + match_result = sub_pattern.match(op, tensor) + if match_result is not None: + return match_result + return None + class MatchResult(object): r"""Encapsulates the result of a match done by GraphMatcher. @@ -102,16 +142,36 @@ class MatchResult(object): return pattern_or_name if isinstance(pattern_or_name, str): + if pattern_or_name not in self._name_to_pattern: + return None return self._name_to_pattern[pattern_or_name] raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' % type(pattern_or_name)) + def _get_op_tensor(self, pattern_or_name): + pattern = self._to_pattern(pattern_or_name) + if pattern is None: + return None + + if pattern not in self._pattern_to_op_tensor: + return None + + return self._pattern_to_op_tensor[pattern] + def get_op(self, pattern_or_name): - return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0] + op_tensor = self._get_op_tensor(pattern_or_name) + return op_tensor[0] if op_tensor else None def get_tensor(self, pattern_or_name): - return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1] + op_tensor = self._get_op_tensor(pattern_or_name) + return op_tensor[1] if op_tensor else None + + def merge_from(self, other_match_result): + # pylint: disable=protected-access + self._pattern_to_op_tensor.update(other_match_result._pattern_to_op_tensor) + self._name_to_pattern.update(other_match_result._name_to_pattern) + # pylint: enable=protected-access class GraphMatcher(object): @@ -121,7 +181,7 @@ class GraphMatcher(object): """Initializes a GraphMatcher. Args: - pattern: The `OpTypePattern` against which `GraphMatcher` matches + pattern: The `Pattern` against which `GraphMatcher` matches subgraphs. """ self._pattern = pattern @@ -133,7 +193,7 @@ class GraphMatcher(object): with key `pattern`. Args: - pattern: An `OpTypePattern`. + pattern: An `Pattern`. op: A `tf.Operation` to match against the pattern. tensor: the output `tf.Tensor` of `op` that is used by the matching op of `pattern`'s parent. Can be None if `pattern` is already the root of the @@ -142,20 +202,11 @@ class GraphMatcher(object): Returns: True if an TF expression rooted at `op` matches `pattern`. """ - if pattern.op_type != '*': - if op.type not in pattern.op_type.split('|'): - return False - - self._match_result.add(pattern, op, tensor) - - if not pattern.inputs: - # If pattern.inputs is empty, skips the rest and accepts all the inputs. - return True - - return len(op.inputs) == len(pattern.inputs) and all([ - self._match_pattern(input_pattern, input_tensor.op, input_tensor) - for input_tensor, input_pattern in zip(op.inputs, pattern.inputs) - ]) + match_result = pattern.match(op, tensor) + if match_result is None: + return False + self._match_result.merge_from(match_result) + return True def match_op(self, op): """Matches `op` against `self._pattern`. diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py index e1572865e423e569ee3b280036c0e02b71b70648..6d587572181c125faa02d36fb54933cff24f11c6 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher_test.py +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -105,7 +105,7 @@ class GraphMatcherTest(test_util.TensorFlowTestCase): self.assertEqual(match_result.get_op(y1_pattern), y1.op) self.assertEqual(match_result.get_tensor(y1_pattern), y1) - def test_oneof_pattern(self): + def test_oneof_type_pattern(self): # - + # / \ / \ # x y z @@ -125,6 +125,44 @@ class GraphMatcherTest(test_util.TensorFlowTestCase): for match_result in matcher.match_graph(g) ], [plus.op, minus.op]) + def test_oneof_pattern(self): + reshape_pattern = graph_matcher.OpTypePattern('Reshape') + transpose_pattern = graph_matcher.OneofPattern([ + graph_matcher.OpTypePattern( + 'Transpose', + name='transpose', + inputs=[ + graph_matcher.OpTypePattern( + 'Slice', name='slice', inputs=[reshape_pattern, '*', '*']), + '*' + ]), + graph_matcher.OpTypePattern( + 'Transpose', name='transpose', inputs=[reshape_pattern, '*']) + ]) + + matcher = graph_matcher.GraphMatcher(transpose_pattern) + + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[6]) + reshape = array_ops.reshape(inputs, [2, 3]) + transpose = array_ops.transpose(reshape) + [match_result] = list(matcher.match_graph(g)) + self.assertEqual(match_result.get_tensor(reshape_pattern), reshape) + self.assertEqual(match_result.get_tensor('slice'), None) + self.assertEqual(match_result.get_op('transpose'), transpose.op) + + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[6]) + reshape = array_ops.reshape(inputs, [2, 3]) + slicing = array_ops.slice(reshape, [0, 0], [-1, -1]) + transpose = array_ops.transpose(slicing) + [match_result] = list(matcher.match_graph(g)) + self.assertEqual(match_result.get_tensor(reshape_pattern), reshape) + self.assertEqual(match_result.get_tensor('slice'), slicing) + self.assertEqual(match_result.get_op('transpose'), transpose.op) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index bbd9743d8014ce495a4967e7484981f7e60ae4a3..89b744c559170e7d9e502d3d8610afaca2c549b7 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -52,9 +52,19 @@ def _create_graph(input_graph, """ # TODO(suharshs): Describe the process in more detail in the doc string. g = copy_graph.CopyGraph(input_graph) + if is_training: + # TODO(raghuramank): Need to make freeze_batch_norm_delay + # a function of the batch size. For now setting this to 250 epochs + # This corresponds to 5 million steps at a batch size of 64. + freeze_batch_norm_delay = 5000000 + else: + freeze_batch_norm_delay = None with g.as_default(): with ops.device(device_name_or_function): - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, + freeze_batch_norm_delay=freeze_batch_norm_delay, + is_training=is_training) quantize.Quantize(g, is_training=is_training) if elements is None: return g diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index 69188a461b353e682807a1630eaa044519ac1b1d..bc383a803496380aaba4d0248d2b7f93253b2b50 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -40,10 +40,14 @@ def _stride_size(node, name_to_node): Args: node: Tensorflow node (NodeDef proto). + name_to_node: For MaxPoolV2, mapping from variable name Tensorflow node. Returns: stride_x: Stride size for horizontal direction (integer). stride_y: Stride size for vertical direction (integer). + + Raises: + ValueError: If stride input cannot be found in `name_to_node`. """ if node.op == "MaxPoolV2": strides_input_name = node.input[2] @@ -159,6 +163,7 @@ def _pool_kernel_size(node, name_to_node): Args: node: Tensorflow node (NodeDef proto). + name_to_node: For MaxPoolV2, mapping from node name to NodeDef. Returns: kernel_size_x: Kernel size for horizontal direction (integer). diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc index c33804906fc21cf2573b79091a76ab1ea86f5966..2def4f3f176b8d4d26c2c94168e9698f14649d94 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc @@ -15,8 +15,8 @@ limitations under the License. #define EIGEN_USE_THREADS -#include #include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h" +#include #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h index 9bb1724a2c0b70ee7ce7238cc179aded95935b26..d8c0a0631d38e55ef9653e0e88e90604ec0f0329 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ #define TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #define Sum(a, b) ((a) + (b)) #define Prod(a, b) ((a) * (b)) @@ -58,11 +58,11 @@ inline T negative_infinity() { } // namespace reduce_functions -#define CALL_ALL_REDUCEOPS(func, ...) \ - func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \ - func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) \ - func(Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \ - func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__) +#define CALL_ALL_REDUCEOPS(func, ...) \ + func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \ + func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) func( \ + Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \ + func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__) #define ReduceSliceFunctorReduceop(reduceop, dummy) \ template \ diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc index 501cddb8c8f4f263aae45e83538af8ee782a935c..9f2be03d718364058da6b63add8752c046798c5b 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc @@ -17,10 +17,10 @@ limitations under the License. #define EIGEN_USE_GPU +#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h" #include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { diff --git a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc index 31e565027f8d229c1268f2d55aec5d7a9074704c..92879ab5356623dfa82fce8dff8db4d3036ae46c 100644 --- a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc +++ b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc @@ -246,9 +246,9 @@ and 'indices' is [[0,1] [1,1] [0,2]], -the the output will be [[ 1, 20, 3] - [ +BIG_VALUE, +BIG_VALUE, +BIG_VALUE] - [ 1, 5, 3]]. +the output will be [[ 1, 20, 3] + [ +BIG_VALUE, +BIG_VALUE, +BIG_VALUE] + [ 1, 5, 3]]. ``` The data must be at least rank 1. The indices can be of shape (?,2) where the diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.cc b/tensorflow/contrib/resampler/kernels/resampler_ops.cc index e02c1b6a2bd9daf9e1f81059f7c1f92106cebc8f..63c72836d793a3df4e96a0134f3a1534c288c8c8 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.cc @@ -36,17 +36,12 @@ using GPUDevice = Eigen::GpuDevice; namespace functor { template -struct Resampler2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const CPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points){ +struct Resampler2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const CPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { const int warp_batch_stride = num_sampling_points * 2; const int data_batch_stride = data_height * data_width * data_channels; const int output_batch_stride = num_sampling_points * data_channels; @@ -59,24 +54,19 @@ struct Resampler2DFunctor{ // The functions take care of performing the relevant pointer // arithmetics abstracting away the low level details in the // main loop over samples. Note that data is stored in NHWC format. - auto set_output = [&](const int sample_id, - const int channel, + auto set_output = [&](const int sample_id, const int channel, const T value) { - output[batch_id * output_batch_stride + - sample_id * data_channels + + output[batch_id * output_batch_stride + sample_id * data_channels + channel] = value; }; - auto get_data_point = [&](const int x, - const int y, - const int chan) { + auto get_data_point = [&](const int x, const int y, const int chan) { const bool point_is_in_range = (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); return point_is_in_range - ? data[batch_id * data_batch_stride + - data_channels * (y * data_width + x) + - chan] - : zero; + ? data[batch_id * data_batch_stride + + data_channels * (y * data_width + x) + chan] + : zero; }; for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) { @@ -89,8 +79,7 @@ struct Resampler2DFunctor{ // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && + if (x > static_cast(-1.0) && y > static_cast(-1.0) && x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. @@ -103,12 +92,10 @@ struct Resampler2DFunctor{ for (int chan = 0; chan < data_channels; ++chan) { const T img_fxfy = dx * dy * get_data_point(fx, fy, chan); - const T img_cxcy = (one - dx) * (one - dy) * - get_data_point(cx, cy, chan); - const T img_fxcy = dx * (one - dy) * - get_data_point(fx, cy, chan); - const T img_cxfy = (one - dx) * dy * - get_data_point(cx, fy, chan); + const T img_cxcy = + (one - dx) * (one - dy) * get_data_point(cx, cy, chan); + const T img_fxcy = dx * (one - dy) * get_data_point(fx, cy, chan); + const T img_cxfy = (one - dx) * dy * get_data_point(cx, fy, chan); set_output(sample_id, chan, img_fxfy + img_cxcy + img_fxcy + img_cxfy); } @@ -125,8 +112,8 @@ struct Resampler2DFunctor{ // estimate of the cost of each work unit is needed to correctly shard the // workload. Shard assumes each cost unit is 1ns, minimum cost per shard // being 10us. - const int64 cost = static_cast(num_sampling_points) * - data_channels * 1000; + const int64 cost = + static_cast(num_sampling_points) * data_channels * 1000; auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost, resample_batches); @@ -138,8 +125,8 @@ struct Resampler2DFunctor{ template class ResamplerOp : public ::tensorflow::OpKernel { public: - explicit ResamplerOp(::tensorflow::OpKernelConstruction* context) : - ::tensorflow::OpKernel(context) {} + explicit ResamplerOp(::tensorflow::OpKernelConstruction* context) + : ::tensorflow::OpKernel(context) {} void Compute(::tensorflow::OpKernelContext* ctx) override { const ::tensorflow::Tensor& data = ctx->input(0); @@ -158,16 +145,17 @@ class ResamplerOp : public ::tensorflow::OpKernel { ::tensorflow::errors::InvalidArgument( "warp should be at least a matrix, got shape ", warp_shape.DebugString())); - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2, + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, ::tensorflow::errors::Unimplemented( "Only bilinear interpolation is supported, warping " "coordinates must be 2D; warp shape last entry should be " - "2, but shape vector is: ", warp_shape.DebugString())); + "2, but shape vector is: ", + warp_shape.DebugString())); OP_REQUIRES(ctx, data_shape.dim_size(0) == warp_shape.dim_size(0), ::tensorflow::errors::InvalidArgument( "Batch size of data and warp tensor must be the same, but " - "input shapes are: ", data_shape.DebugString(), ", ", - warp_shape.DebugString())); + "input shapes are: ", + data_shape.DebugString(), ", ", warp_shape.DebugString())); const int batch_size = data_shape.dim_size(0); const int data_height = data_shape.dim_size(1); const int data_width = data_shape.dim_size(2); @@ -180,16 +168,10 @@ class ResamplerOp : public ::tensorflow::OpKernel { // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU. if (num_sampling_points > 0) { - functor::Resampler2DFunctor()(ctx, - ctx->eigen_device(), - data.flat().data(), - warp.flat().data(), - output->flat().data(), - batch_size, - data_height, - data_width, - data_channels, - num_sampling_points); + functor::Resampler2DFunctor()( + ctx, ctx->eigen_device(), data.flat().data(), + warp.flat().data(), output->flat().data(), batch_size, + data_height, data_width, data_channels, num_sampling_points); } } @@ -197,12 +179,9 @@ class ResamplerOp : public ::tensorflow::OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ResamplerOp); }; - -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Resampler") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Resampler").Device(DEVICE_CPU).TypeConstraint("T"), \ ResamplerOp); TF_CALL_half(REGISTER); @@ -211,40 +190,32 @@ TF_CALL_double(REGISTER); #undef REGISTER #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("Resampler") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - ResamplerOp) +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Resampler").Device(DEVICE_GPU).TypeConstraint("T"), \ + ResamplerOp) TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); #undef REGISTER #endif // GOOGLE_CUDA - namespace functor { template -struct ResamplerGrad2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const CPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points){ +struct ResamplerGrad2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const CPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { // Set gradients to 0, because the kernel incrementally updates the // tensor entries by adding partial contributions. - const int resampler_output_size = batch_size * num_sampling_points * - data_channels; + const int resampler_output_size = + batch_size * num_sampling_points * data_channels; const int grad_warp_size = resampler_output_size / data_channels * 2; - const int grad_data_size = data_height * data_width * data_channels * - batch_size; + const int grad_data_size = + data_height * data_width * data_channels * batch_size; memset(grad_data, 0, sizeof(T) * grad_data_size); memset(grad_warp, 0, sizeof(T) * grad_warp_size); @@ -260,35 +231,29 @@ struct ResamplerGrad2DFunctor{ // The functions take care of performing the relevant pointer // arithmetics abstracting away the low level details in the // main loop over samples. Note that data is stored in NHWC format. - auto get_data_point = [&](const int x, - const int y, - const int chan) { + auto get_data_point = [&](const int x, const int y, const int chan) { const bool point_is_in_range = - (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); + (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); return point_is_in_range - ? data[batch_id * data_batch_stride + - data_channels * (y * data_width + x) + - chan] - : zero; + ? data[batch_id * data_batch_stride + + data_channels * (y * data_width + x) + chan] + : zero; }; auto update_grad_data = [&](const int x, const int y, const int chan, const T value) { const bool point_is_in_range = (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); - if (point_is_in_range){ + if (point_is_in_range) { grad_data[batch_id * data_batch_stride + - data_channels * (y * data_width + x) + - chan] += value; + data_channels * (y * data_width + x) + chan] += value; } }; - auto update_grad_warp = [&](const int sample_id, - const int channel, + auto update_grad_warp = [&](const int sample_id, const int channel, const T value) { - grad_warp[batch_id * warp_batch_stride + - sample_id * 2 + - channel] += value; + grad_warp[batch_id * warp_batch_stride + sample_id * 2 + channel] += + value; }; for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) { @@ -301,8 +266,7 @@ struct ResamplerGrad2DFunctor{ // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && + if (x > static_cast(-1.0) && y > static_cast(-1.0) && x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. @@ -316,27 +280,25 @@ struct ResamplerGrad2DFunctor{ for (int chan = 0; chan < data_channels; ++chan) { const T grad_output_value = grad_output[batch_id * output_batch_stride + - sample_id * data_channels + - chan]; + sample_id * data_channels + chan]; const T img_fxfy = get_data_point(fx, fy, chan); const T img_cxcy = get_data_point(cx, cy, chan); const T img_fxcy = get_data_point(fx, cy, chan); const T img_cxfy = get_data_point(cx, fy, chan); // Update partial gradients wrt relevant warp field entries - update_grad_warp(sample_id, 0, - grad_output_value * - ((one - dy) * (img_cxcy - img_fxcy) + - dy * (img_cxfy - img_fxfy))); + update_grad_warp( + sample_id, 0, + grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) + + dy * (img_cxfy - img_fxfy))); - update_grad_warp(sample_id, 1, - grad_output_value * - ((one - dx) * (img_cxcy - img_cxfy) + - dx * (img_fxcy - img_fxfy))); + update_grad_warp( + sample_id, 1, + grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) + + dx * (img_fxcy - img_fxfy))); // Update partial gradients wrt sampled data - update_grad_data(fx, fy, chan, - grad_output_value * dx * dy); + update_grad_data(fx, fy, chan, grad_output_value * dx * dy); update_grad_data(cx, cy, chan, grad_output_value * (one - dx) * (one - dy)); update_grad_data(fx, cy, chan, @@ -355,8 +317,8 @@ struct ResamplerGrad2DFunctor{ // being 10us. // TODO(fviola): Check out if there is a better way of doing this. auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); - const int64 cost = static_cast(num_sampling_points) * - data_channels * 1000; + const int64 cost = + static_cast(num_sampling_points) * data_channels * 1000; ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost, update_grads_for_batches); } @@ -364,12 +326,11 @@ struct ResamplerGrad2DFunctor{ } // namespace functor - template class ResamplerGradOp : public ::tensorflow::OpKernel { public: - explicit ResamplerGradOp(::tensorflow::OpKernelConstruction* context) : - ::tensorflow::OpKernel(context) {} + explicit ResamplerGradOp(::tensorflow::OpKernelConstruction* context) + : ::tensorflow::OpKernel(context) {} void Compute(::tensorflow::OpKernelContext* ctx) override { const ::tensorflow::Tensor& data = ctx->input(0); @@ -383,7 +344,7 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { "tensor must be a batch of 2d data; data shape should have " "4 entries corresponding to [batch_size, data_height, " "data_width, data_channels], but is: ", - data_shape.DebugString())); + data_shape.DebugString())); const int batch_size = data_shape.dim_size(0); const int data_height = data_shape.dim_size(1); const int data_width = data_shape.dim_size(2); @@ -394,7 +355,7 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { ::tensorflow::errors::InvalidArgument( "warp should be at least a matrix, got shape ", warp_shape.DebugString())); - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2, + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, ::tensorflow::errors::Unimplemented( "Only bilinear interpolation is supported, warping " "coordinates must be 2D; warp shape last entry should be " @@ -417,18 +378,11 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(1, warp.shape(), &grad_warp)); // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU. if (num_sampling_points > 0) { - functor::ResamplerGrad2DFunctor()(ctx, - ctx->eigen_device(), - data.flat().data(), - warp.flat().data(), - grad_output.flat().data(), - grad_data->flat().data(), - grad_warp->flat().data(), - batch_size, - data_height, - data_width, - data_channels, - num_sampling_points); + functor::ResamplerGrad2DFunctor()( + ctx, ctx->eigen_device(), data.flat().data(), + warp.flat().data(), grad_output.flat().data(), + grad_data->flat().data(), grad_warp->flat().data(), batch_size, + data_height, data_width, data_channels, num_sampling_points); } } @@ -436,11 +390,9 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ResamplerGradOp); }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("ResamplerGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("ResamplerGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ ResamplerGradOp); TF_CALL_half(REGISTER); @@ -449,11 +401,10 @@ TF_CALL_double(REGISTER); #undef REGISTER #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ResamplerGrad") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - ResamplerGradOp) +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("ResamplerGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + ResamplerGradOp) // Disable half and double precision since atomicAdds are not supported // TF_CALL_half(REGISTER); // TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.h b/tensorflow/contrib/resampler/kernels/resampler_ops.h index 85d3676efac70fe9237d31c2be1fe75e67d70abd..7fe3b9c0df71f51e07d38ea15a672d79fdc70453 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops.h +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.h @@ -29,38 +29,25 @@ namespace functor { // Helper functor for the Resampler Op in 2D template -struct Resampler2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const Device& d, - const T* __restrict__ data, - const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points); +struct Resampler2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const Device& d, + const T* __restrict__ data, const T* __restrict__ warp, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points); }; - // Helper functor for the Resampler Gradient Op in 2D template -struct ResamplerGrad2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const Device& d, - const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points); +struct ResamplerGrad2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const Device& d, + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points); }; - } // namespace functor } // namespace tensorflow diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc index 636847a212f27c738032128e3f3f653ec32f851b..3c07051f685c74b6e45fb782c80871f38dffbbf4 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc @@ -31,18 +31,15 @@ using GPUDevice = Eigen::GpuDevice; namespace { -#define GET_DATA_POINT(x, y) \ - data[batch_id * data_batch_stride + \ - data_channels * (y * data_width + x) + \ +#define GET_DATA_POINT(x, y) \ + data[batch_id * data_batch_stride + data_channels * (y * data_width + x) + \ chan] template __global__ void Resampler2DKernel(const T* __restrict__ data, const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, const int data_channels, const int num_sampling_points) { const int output_data_size = batch_size * num_sampling_points * data_channels; @@ -75,10 +72,8 @@ __global__ void Resampler2DKernel(const T* __restrict__ data, // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && - x < static_cast(data_width) && - y < static_cast(data_height)) { + if (x > static_cast(-1.0) && y > static_cast(-1.0) && + x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. const int fx = std::floor(static_cast(x)); const int fy = std::floor(static_cast(y)); @@ -87,21 +82,20 @@ __global__ void Resampler2DKernel(const T* __restrict__ data, const T dx = static_cast(cx) - x; const T dy = static_cast(cy) - y; - const T img_fxfy = (fx >= 0 && fy >= 0) - ? dx * dy * GET_DATA_POINT(fx, fy) - : zero; + const T img_fxfy = + (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero; const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1) - ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy) - : zero; + ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy) + : zero; const T img_fxcy = (fx >= 0 && cy <= data_height - 1) - ? dx * (one - dy) * GET_DATA_POINT(fx, cy) - : zero; + ? dx * (one - dy) * GET_DATA_POINT(fx, cy) + : zero; const T img_cxfy = (cx <= data_width - 1 && fy >= 0) - ? (one - dx) * dy * GET_DATA_POINT(cx, fy) - : zero; + ? (one - dx) * dy * GET_DATA_POINT(cx, fy) + : zero; output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy; } else { @@ -115,24 +109,20 @@ __global__ void Resampler2DKernel(const T* __restrict__ data, namespace functor { template -struct Resampler2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const GPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points) { - const int output_data_size = batch_size * num_sampling_points * data_channels; - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(output_data_size, d); - Resampler2DKernel - <<>>( - data, warp, output, batch_size, data_height, data_width, - data_channels, num_sampling_points); +struct Resampler2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { + const int output_data_size = + batch_size * num_sampling_points * data_channels; + ::tensorflow::CudaLaunchConfig config = + ::tensorflow::GetCudaLaunchConfig(output_data_size, d); + Resampler2DKernel + <<>>( + data, warp, output, batch_size, data_height, data_width, + data_channels, num_sampling_points); } }; @@ -145,26 +135,20 @@ template struct Resampler2DFunctor; namespace { -#define UPDATE_GRAD_DATA_POINT(x, y, v) \ - atomicAdd(grad_data + (batch_id * data_batch_stride + \ - data_channels * (y * data_width + x) + \ - chan), \ +#define UPDATE_GRAD_DATA_POINT(x, y, v) \ + atomicAdd(grad_data + (batch_id * data_batch_stride + \ + data_channels * (y * data_width + x) + chan), \ v) - template -__global__ void ResamplerGrad2DKernel(const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points) { - const int resampler_output_size = batch_size * num_sampling_points * - data_channels; +__global__ void ResamplerGrad2DKernel( + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, const int data_height, + const int data_width, const int data_channels, + const int num_sampling_points) { + const int resampler_output_size = + batch_size * num_sampling_points * data_channels; CUDA_1D_KERNEL_LOOP(index, resampler_output_size) { const int out_index = index; @@ -199,10 +183,8 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && - x < static_cast(data_width) && - y < static_cast(data_height)) { + if (x > static_cast(-1.0) && y > static_cast(-1.0) && + x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. const int fx = std::floor(static_cast(x)); const int fy = std::floor(static_cast(y)); @@ -211,21 +193,17 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, const T dx = static_cast(cx) - x; const T dy = static_cast(cy) - y; - const T img_fxfy = (fx >= 0 && fy >= 0) - ? GET_DATA_POINT(fx, fy) - : zero; + const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero; const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1) - ? GET_DATA_POINT(cx, cy) - : zero; + ? GET_DATA_POINT(cx, cy) + : zero; - const T img_fxcy = (fx >= 0 && cy <= data_height - 1) - ? GET_DATA_POINT(fx, cy) - : zero; + const T img_fxcy = + (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero; - const T img_cxfy = (cx <= data_width - 1 && fy >= 0) - ? GET_DATA_POINT(cx, fy) - : zero; + const T img_cxfy = + (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero; // Update partial gradients wrt relevant warp field entries atomicAdd(grad_warp + warp_id_x, @@ -241,7 +219,7 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, } if (cx <= data_width - 1 && cy <= data_height - 1) { UPDATE_GRAD_DATA_POINT(cx, cy, - grad_output_value * (one - dx) * (one - dy)); + grad_output_value * (one - dx) * (one - dy)); } if (fx >= 0 && cy <= data_height - 1) { UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy)); @@ -261,43 +239,37 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, namespace functor { template -struct ResamplerGrad2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const GPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points) { - // Set gradients to 0, because the kernel incrementally updates the - // tensor entries by adding partial contributions. - const int grad_warp_size = batch_size * num_sampling_points * 2; - const int grad_data_size = batch_size * data_height * data_width * - data_channels; - - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d); - ::tensorflow::SetZero - <<>>( - grad_warp_size, grad_warp); - - config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d); - ::tensorflow::SetZero - <<>>( - grad_data_size, grad_data); - - const int resampler_output_size = batch_size * num_sampling_points * - data_channels; - config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d); - ResamplerGrad2DKernel - <<>>( - data, warp, grad_output, grad_data, grad_warp, batch_size, - data_height, data_width, data_channels, num_sampling_points); +struct ResamplerGrad2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { + // Set gradients to 0, because the kernel incrementally updates the + // tensor entries by adding partial contributions. + const int grad_warp_size = batch_size * num_sampling_points * 2; + const int grad_data_size = + batch_size * data_height * data_width * data_channels; + + ::tensorflow::CudaLaunchConfig config = + ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d); + ::tensorflow:: + SetZero<<>>( + grad_warp_size, grad_warp); + + config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d); + ::tensorflow:: + SetZero<<>>( + grad_data_size, grad_data); + + const int resampler_output_size = + batch_size * num_sampling_points * data_channels; + config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d); + ResamplerGrad2DKernel + <<>>( + data, warp, grad_output, grad_data, grad_warp, batch_size, + data_height, data_width, data_channels, num_sampling_points); } }; diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.cc b/tensorflow/contrib/rnn/kernels/blas_gemm.cc index e62501e9b100484a7be3cc6ae0fc25905c0d0724..03006dab323a7c6dc83d9a17c035ef705f7b0366 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.cc +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.cc @@ -36,11 +36,10 @@ perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory) { namespace functor { template -void TensorCuBlasGemm::operator()(OpKernelContext* ctx, - bool transa, bool transb, uint64 m, - uint64 n, uint64 k, T alpha, const T* a, - int lda, const T* b, int ldb, T beta, T* c, - int ldc) { +void TensorCuBlasGemm::operator()(OpKernelContext* ctx, bool transa, + bool transb, uint64 m, uint64 n, uint64 k, + T alpha, const T* a, int lda, const T* b, + int ldb, T beta, T* c, int ldc) { #if GOOGLE_CUDA perftools::gputools::blas::Transpose trans[] = { perftools::gputools::blas::Transpose::kNoTranspose, diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.cc b/tensorflow/contrib/rnn/kernels/gru_ops.cc index 0796f82b214620dd71d154fb8f8ec953dbcbb9ec..bd3d898fb09da0f490050c85b1e585502d8ecb2c 100644 --- a/tensorflow/contrib/rnn/kernels/gru_ops.cc +++ b/tensorflow/contrib/rnn/kernels/gru_ops.cc @@ -15,8 +15,8 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/contrib/rnn/kernels/gru_ops.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -61,9 +61,9 @@ class GRUCellBlockOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, @@ -82,10 +82,10 @@ class GRUCellBlockOp : public OpKernel { "w_c.dim_size(0) != input_size + cell_size: ", w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_c_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("w_c.dim_size(1) != cell_size: ", - w_c_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'b_ru' must be [2*cell_size] OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, @@ -97,10 +97,10 @@ class GRUCellBlockOp : public OpKernel { errors::InvalidArgument("Rank of b_ru must be 1", b_ru_tensor->dims(), " vs. 1", 1)); // Shape of 'b_c' must be [cell_size] - OP_REQUIRES( - ctx, b_c_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("b_c.dim_size(0) != cell_size: ", - b_c_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), + " vs. ", cell_size)); OP_REQUIRES(ctx, b_c_tensor->dims() == 1, errors::InvalidArgument("Rank of b_c must be 1", b_c_tensor->dims(), " vs. 1")); @@ -216,9 +216,9 @@ class GRUBlockCellGradOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, @@ -237,10 +237,10 @@ class GRUBlockCellGradOp : public OpKernel { "w_c.dim_size(0) != input_size + cell_size: ", w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_c_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("w_c.dim_size(1) != cell_size: ", - w_c_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'b_ru' must be [2*cell_size] OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, @@ -253,54 +253,54 @@ class GRUBlockCellGradOp : public OpKernel { b_ru_tensor->dims(), " vs. 1")); // Shape of 'b_c' must be [cell_size] - OP_REQUIRES( - ctx, b_c_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("b_c.dim_size(0) != cell_size: ", - b_c_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), + " vs. ", cell_size)); OP_REQUIRES(ctx, b_c_tensor->dims() == 1, errors::InvalidArgument("Rank of b_c must be 1 ", b_c_tensor->dims(), " vs. 1")); // Shape of 'r' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, r_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("r.dims(0) != batch_size: ", - r_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, r_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("r.dims(1) != cell_size: ", - r_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, r_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "r.dims(0) != batch_size: ", r_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, r_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "r.dims(1) != cell_size: ", r_tensor->dim_size(1), " vs. ", + cell_size)); // Shape of 'u' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, u_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("u.dims(0) != batch_size: ", - u_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, u_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("u.dims(1) != cell_size: ", - u_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, u_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "u.dims(0) != batch_size: ", u_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, u_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "u.dims(1) != cell_size: ", u_tensor->dim_size(1), " vs. ", + cell_size)); // Shape of 'c' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, c_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("c.dims(0) != batch_size: ", - c_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, c_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("c.dims(1) != cell_size: ", - c_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, c_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "c.dims(0) != batch_size: ", c_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, c_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "c.dims(1) != cell_size: ", c_tensor->dim_size(1), " vs. ", + cell_size)); // Shape of 'd_h' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, d_h_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("d_h.dims(0) != batch_size: ", - d_h_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, d_h_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("d_h.dims(1) != cell_size: ", - d_h_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, d_h_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "d_h.dims(0) != batch_size: ", d_h_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, d_h_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "d_h.dims(1) != cell_size: ", d_h_tensor->dim_size(1), + " vs. ", cell_size)); // Create output tensors. Tensor* d_x_tensor = nullptr; diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index 941a457fd3ada312b981fb23c769ff9ecea9ff13..5e7cf0ce84d332bd24088cd78995f7843813328b 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -281,23 +281,23 @@ class LSTMBlockCellOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w.dim_size(0) != input_size + cell_size: ", w_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_tensor->dim_size(1) == cell_size * 4, - errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", - w_tensor->dim_size(1), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument( + "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1), + " vs. ", cell_size * 4)); - OP_REQUIRES( - ctx, b_tensor->dim_size(0) == cell_size * 4, - errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", - b_tensor->dim_size(0), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument( + "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0), + " vs. ", cell_size * 4)); // Allocate our output tensors. Tensor* i_tensor = nullptr; @@ -484,77 +484,77 @@ class LSTMBlockCellGradOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w.dim_size(0) != input_size + cell_size: ", w_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_tensor->dim_size(1) == cell_size * 4, - errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", - w_tensor->dim_size(1), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument( + "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1), + " vs. ", cell_size * 4)); - OP_REQUIRES( - ctx, b_tensor->dim_size(0) == cell_size * 4, - errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", - b_tensor->dim_size(0), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument( + "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0), + " vs. ", cell_size * 4)); - OP_REQUIRES( - ctx, i_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("i.dim_size(0) != batch_size: ", - i_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, i_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("i.dim_size(1) != cell_size: ", - i_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, cs_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("cs.dim_size(0) != batch_size: ", - cs_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, cs_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("cs.dim_size(1) != cell_size: ", - cs_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, f_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("f.dim_size(0) != batch_size: ", - f_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, f_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("i.dim_size(1) != cell_size: ", - f_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "i.dim_size(1) != cell_size: ", f_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, o_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("o.dim_size(0) != batch_size: ", - o_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, o_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("o.dim_size(1) != cell_size: ", - o_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, ci_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("ci.dim_size(0) != batch_size: ", - ci_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, ci_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("ci.dim_size(1) != cell_size: ", - ci_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, co_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("co.dim_size(0) != batch_size: ", - co_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, co_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("co.dim_size(1) != cell_size: ", - co_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1), + " vs. ", cell_size)); OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size, errors::InvalidArgument( @@ -860,9 +860,9 @@ class BlockLSTMOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); const Tensor* w_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); @@ -872,46 +872,46 @@ class BlockLSTMOp : public OpKernel { errors::InvalidArgument( "w.dim_size(0) != input_size + cell_size: ", w_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_tensor->dim_size(1) == cell_size * 4, - errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", - w_tensor->dim_size(1), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument( + "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1), + " vs. ", cell_size * 4)); const Tensor* wci_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); OP_REQUIRES(ctx, wci_tensor->dims() == 1, errors::InvalidArgument("wci must be 1D")); - OP_REQUIRES( - ctx, wci_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("wci.dim_size(0) != cell_size: ", - wci_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0), + " vs. ", cell_size)); const Tensor* wcf_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor)); OP_REQUIRES(ctx, wcf_tensor->dims() == 1, errors::InvalidArgument("wcf must be 1D")); - OP_REQUIRES( - ctx, wcf_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("wcf.dim_size(0) != cell_size: ", - wcf_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0), + " vs. ", cell_size)); const Tensor* wco_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor)); OP_REQUIRES(ctx, wco_tensor->dims() == 1, errors::InvalidArgument("wco must be 1D")); - OP_REQUIRES( - ctx, wco_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("wco.dim_size(0) != cell_size: ", - wco_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0), + " vs. ", cell_size)); const Tensor* b_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor)); OP_REQUIRES(ctx, b_tensor->dims() == 1, errors::InvalidArgument("b must be 1D")); - OP_REQUIRES( - ctx, b_tensor->dim_size(0) == cell_size * 4, - errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", - b_tensor->dim_size(0), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument( + "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0), + " vs. ", cell_size * 4)); TensorShape batch_cell_shape({timelen, batch_size, cell_size}); Tensor* i_out; @@ -1065,9 +1065,9 @@ class BlockLSTMGradOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); const int64 cell_size = w_tensor->dim_size(1) / 4; OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0), - errors::InvalidArgument("w matrix rows don't match: ", - input_size + cell_size, " vs. ", - w_tensor->dim_size(0))); + errors::InvalidArgument( + "w matrix rows don't match: ", input_size + cell_size, + " vs. ", w_tensor->dim_size(0))); const Tensor* wci_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); @@ -1193,7 +1193,6 @@ class BlockLSTMGradOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::v(), batch_cell_shape, &h_grad_tensor)); - const Device& device = ctx->eigen_device(); functor::TensorZero()(device, cs_grad_tensor.flat()); diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index bc6b85f3f1ab80b5ef5b4a8ba2e5242cf451adbe..d23cedc234b8c0e1a784346f28164ae79b8cbf89 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -92,7 +92,6 @@ struct TensorZeroPadding { } }; - struct LSTMBlockCell { LSTMBlockCell(const int batch_size, const int input_size, const int cell_size) : batch_size_(batch_size), diff --git a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc index 544cd163c50062093acf7f5e942f67606936c0e3..68184b643e5e7a04ffecb804703051638514b7b2 100644 --- a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc +++ b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc @@ -149,8 +149,9 @@ TEST_F(LSTMOpsTest, BlockLSTMGrad_ShapeFn) { INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;?;?;[1,?]" + suffix); // Output with all input knowns makes known rank outputs. - INFER_OK(op, JoinedCopies("?", 18), "[?,?,?];" + JoinedCopies("[?,?]", 3) + - ";" + JoinedCopies("[?]", 4)); + INFER_OK( + op, JoinedCopies("?", 18), + "[?,?,?];" + JoinedCopies("[?,?]", 3) + ";" + JoinedCopies("[?]", 4)); // Output with copies input shapes to output. string input = strings::StrCat("?;[?,?,?];", JoinedCopies("[?,?]", 3), ";", diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 5711f41cc38a95c2b4febb6d3f1ddf11b4fd9843..9b84635e85e8d300be4a77a4cc261b70d14ae2ac 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -42,7 +42,6 @@ from tensorflow.python.platform import test from tensorflow.python.framework import test_util from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell - # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name @@ -84,19 +83,22 @@ class RNNCellTest(test.TestCase): ], [v.name for v in cell.trainable_variables]) self.assertFalse(cell.non_trainable_variables) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[0].shape, (1, 2)) def testBasicRNNCellNotTrainable(self): with self.test_session() as sess: + def not_trainable_getter(getter, *args, **kwargs): kwargs["trainable"] = False return getter(*args, **kwargs) with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5), + "root", + initializer=init_ops.constant_initializer(0.5), custom_getter=not_trainable_getter): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 2]) @@ -108,9 +110,10 @@ class RNNCellTest(test.TestCase): "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME ], [v.name for v in cell.non_trainable_variables]) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[0].shape, (1, 2)) def testGRUCell(self): @@ -121,9 +124,10 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 2]) g, _ = rnn_cell_impl.GRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) # Smoke test self.assertAllClose(res[0], [[0.175991, 0.175991]]) with variable_scope.variable_scope( @@ -133,10 +137,10 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 2]) g, _ = rnn_cell_impl.GRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], - {x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) # Smoke test self.assertAllClose(res[0], [[0.156736, 0.156736]]) @@ -148,12 +152,13 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 2]) g, _ = contrib_rnn_cell.SRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) # Smoke test - self.assertAllClose(res[0], [[0.509682, 0.509682]]) - + self.assertAllClose(res[0], [[0.509682, 0.509682]]) + def testSRUCellWithDiffSize(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -178,8 +183,7 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 8], dtype=dtype) cell = rnn_cell_impl.MultiRNNCell( [ - rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False) + rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) for _ in range(2) ], state_is_tuple=False) @@ -197,22 +201,21 @@ class RNNCellTest(test.TestCase): "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME ] - self.assertEqual( - expected_variable_names, - [v.name for v in cell.trainable_variables]) + self.assertEqual(expected_variable_names, + [v.name for v in cell.trainable_variables]) self.assertFalse(cell.non_trainable_variables) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1.]]), - m.name: 0.1 * np.ones([1, 8])}) + res = sess.run([g, out_m], { + x.name: np.array([[1., 1.]]), + m.name: 0.1 * np.ones([1, 8]) + }) self.assertEqual(len(res), 2) variables = variables_lib.global_variables() self.assertEqual(expected_variable_names, [v.name for v in variables]) # The numbers in results were not calculated, this is just a # smoke test. - self.assertAllClose( - res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2) + self.assertAllClose(res[0], np.array( + [[0.240, 0.240]], dtype=np_dtype), 1e-2) expected_mem = np.array( [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], dtype=np_dtype) @@ -222,13 +225,13 @@ class RNNCellTest(test.TestCase): # Test BasicLSTMCell with input_size != num_units. x = array_ops.zeros([1, 3], dtype=dtype) m = array_ops.zeros([1, 4], dtype=dtype) - g, out_m = rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False)(x, m) + g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)}) + [g, out_m], { + x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + m.name: 0.1 * np.ones([1, 4], dtype=np_dtype) + }) self.assertEqual(len(res), 2) def testBasicLSTMCellDimension0Error(self): @@ -246,9 +249,11 @@ class RNNCellTest(test.TestCase): g, out_m = rnn_cell_impl.BasicLSTMCell( num_units, state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) - sess.run([g, out_m], - {x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size - 1, state_size])}) + sess.run( + [g, out_m], { + x.name: 1 * np.ones([batch_size, input_size]), + m.name: 0.1 * np.ones([batch_size - 1, state_size]) + }) def testBasicLSTMCellStateSizeError(self): """Tests that state_size must be num_units * 2.""" @@ -265,9 +270,11 @@ class RNNCellTest(test.TestCase): g, out_m = rnn_cell_impl.BasicLSTMCell( num_units, state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) - sess.run([g, out_m], - {x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size, state_size])}) + sess.run( + [g, out_m], { + x.name: 1 * np.ones([batch_size, input_size]), + m.name: 0.1 * np.ones([batch_size, state_size]) + }) def testBasicLSTMCellStateTupleType(self): with self.test_session(): @@ -315,11 +322,12 @@ class RNNCellTest(test.TestCase): state_is_tuple=True) g, (out_m0, out_m1) = cell(x, (m0, m1)) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, out_m0, out_m1], { - x.name: np.array([[1., 1.]]), - m0.name: 0.1 * np.ones([1, 4]), - m1.name: 0.1 * np.ones([1, 4]) - }) + res = sess.run( + [g, out_m0, out_m1], { + x.name: np.array([[1., 1.]]), + m0.name: 0.1 * np.ones([1, 4]), + m1.name: 0.1 * np.ones([1, 4]) + }) self.assertEqual(len(res), 3) # The numbers in results were not calculated, this is just a smoke test. # Note, however, these values should match the original @@ -350,10 +358,11 @@ class RNNCellTest(test.TestCase): state_is_tuple=False) output, state = cell(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([output, state], { - x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), - m.name: 0.1 * np.ones((batch_size, state_size)) - }) + res = sess.run( + [output, state], { + x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), + m.name: 0.1 * np.ones((batch_size, state_size)) + }) self.assertEqual(len(res), 2) # The numbers in results were not calculated, this is mostly just a # smoke test. @@ -456,10 +465,10 @@ class RNNCellTest(test.TestCase): rnn_cell_impl.GRUCell(3), num_proj=3) g, new_m = cell(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, new_m], - {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]])}) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) self.assertEqual(res[1].shape, (1, 3)) # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) @@ -493,9 +502,11 @@ class RNNCellTest(test.TestCase): base_cell = rnn_cell_impl.GRUCell(3) g, m_new = base_cell(x, m) variable_scope.get_variable_scope().reuse_variables() + def residual_with_slice_fn(inp, out): inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) return inp_sliced + out + g_res, m_new_res = rnn_cell_impl.ResidualWrapper( base_cell, residual_with_slice_fn)(x, m) sess.run([variables_lib.global_variables_initializer()]) @@ -565,10 +576,10 @@ class RNNCellTest(test.TestCase): self.assertEqual(embedding_cell.output_size, 2) g, new_m = embedding_cell(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, new_m], - {x.name: np.array([[1]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g, new_m], { + x.name: np.array([[1]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[1].shape, (1, 2)) # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res[0], [[0.17139, 0.17139]]) @@ -598,8 +609,8 @@ class RNNCellTest(test.TestCase): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 4]) _, ml = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) - for _ in range(2)], state_is_tuple=False)(x, m) + [rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run(ml, { x.name: np.array([[1., 1.]]), @@ -619,19 +630,20 @@ class RNNCellTest(test.TestCase): # Test incorrectness of state with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) - for _ in range(2)], state_is_tuple=True)(x, m_bad) + [rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_bad) _, ml = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) - for _ in range(2)], state_is_tuple=True)(x, m_good) + [rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_good) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run(ml, { - x.name: np.array([[1., 1.]]), - m_good[0].name: np.array([[0.1, 0.1]]), - m_good[1].name: np.array([[0.1, 0.1]]) - }) + res = sess.run( + ml, { + x.name: np.array([[1., 1.]]), + m_good[0].name: np.array([[0.1, 0.1]]), + m_good[1].name: np.array([[0.1, 0.1]]) + }) # The numbers in results were not calculated, this is just a # smoke test. However, these numbers should match those of @@ -642,8 +654,11 @@ class RNNCellTest(test.TestCase): class DropoutWrapperTest(test.TestCase): - def _testDropoutWrapper(self, batch_size=None, time_steps=None, - parallel_iterations=None, **kwargs): + def _testDropoutWrapper(self, + batch_size=None, + time_steps=None, + parallel_iterations=None, + **kwargs): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): @@ -654,14 +669,14 @@ class DropoutWrapperTest(test.TestCase): x = constant_op.constant( [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) m = rnn_cell_impl.LSTMStateTuple( - *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32) - ] * 2) + *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32 + )] * 2) else: x = constant_op.constant( np.random.randn(time_steps, batch_size, 3).astype(np.float32)) m = rnn_cell_impl.LSTMStateTuple(*[ - constant_op.constant( - [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) + constant_op. + constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) ] * 2) outputs, final_state = rnn.dynamic_rnn( cell=rnn_cell_impl.DropoutWrapper( @@ -688,8 +703,8 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) self.assertAllClose(true_full_output, res[0]) @@ -701,8 +716,8 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) self.assertAllClose(true_full_output, res[0]) @@ -717,16 +732,20 @@ class DropoutWrapperTest(test.TestCase): ## consistent across both calls. Otherwise the seed may not end ## up being munged consistently across both graphs. res_standard_1 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, seed=10, + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + seed=10, parallel_iterations=1) # Clear away the graph and the test session (which keeps variables around) ops.reset_default_graph() self._ClearCachedSession() random_seed.set_random_seed(2) res_standard_2 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, seed=10, + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + seed=10, parallel_iterations=1) self.assertAllClose(res_standard_1[0], res_standard_2[0]) self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) @@ -736,11 +755,12 @@ class DropoutWrapperTest(test.TestCase): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-10) res = self._testDropoutWrapper( - input_keep_prob=keep_all, output_keep_prob=keep_none, + input_keep_prob=keep_all, + output_keep_prob=keep_none, state_keep_prob=keep_all) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) self.assertAllClose(np.zeros(res[0].shape), res[0]) @@ -753,13 +773,13 @@ class DropoutWrapperTest(test.TestCase): # Even though we dropout state, by default DropoutWrapper never # drops out the memory ("c") term of an LSTMStateTuple. res = self._testDropoutWrapper( - input_keep_prob=keep_all, output_keep_prob=keep_all, + input_keep_prob=keep_all, + output_keep_prob=keep_all, state_keep_prob=keep_none) - true_c_state = np.array( - [[1.713925, 1.713925, 1.713925]], dtype=np.float32) + true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) self.assertAllClose(true_full_output[0], res[0][0]) # Second output is modified by zero input state self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4) @@ -772,13 +792,14 @@ class DropoutWrapperTest(test.TestCase): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-10) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) # All outputs are different because inputs are zeroed out res = self._testDropoutWrapper( - input_keep_prob=keep_none, output_keep_prob=keep_all, + input_keep_prob=keep_none, + output_keep_prob=keep_all, state_keep_prob=keep_all) self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) @@ -788,9 +809,13 @@ class DropoutWrapperTest(test.TestCase): keep_some = 0.8 keep_all = variable_scope.get_variable("all", initializer=1.0) res = self._testDropoutWrapper( - input_keep_prob=keep_all, output_keep_prob=keep_some, - state_keep_prob=keep_all, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7) + input_keep_prob=keep_all, + output_keep_prob=keep_some, + state_keep_prob=keep_all, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7) # Ensure the same dropout pattern for all time steps output_mask = np.abs(res[0]) > 1e-6 for m in output_mask[1:]: @@ -799,9 +824,13 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperRecurrentStateInputAndOutput(self): keep_some = 0.9 res = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7) + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7) # Smoke test for the state/input masks. output_mask = np.abs(res[0]) > 1e-6 @@ -825,17 +854,27 @@ class DropoutWrapperTest(test.TestCase): random_seed.set_random_seed(2347) np.random.seed(23487) res0 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7, seed=-234987) + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7, + seed=-234987) ops.reset_default_graph() self._ClearCachedSession() random_seed.set_random_seed(2347) np.random.seed(23487) res1 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7, seed=-234987) + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7, + seed=-234987) output_mask = np.abs(res0[0]) > 1e-6 for time_step in output_mask: @@ -872,9 +911,10 @@ class SlimRNNCellTest(test.TestCase): g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m) # pylint: enable=protected-access sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[0].shape, (1, 2)) def testBasicRNNCellMatch(self): diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 8a3894ef9d7042e66b52edefdf08b278dcc6c4f4..7b883ebc5d7756f1bdf445f900500a4b89e6cffd 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -1545,97 +1545,6 @@ class BenchmarkLSTMCellXLA(test.Benchmark): ])) -class WeightNormLSTMCellTest(test.TestCase): - """Compared cell output with pre-calculated values.""" - - def _cell_output(self, cell): - """Calculate cell output""" - - with self.test_session() as sess: - init = init_ops.constant_initializer(0.5) - with variable_scope.variable_scope("root", initializer=init): - x = array_ops.zeros([1, 2]) - c0 = array_ops.zeros([1, 2]) - h0 = array_ops.zeros([1, 2]) - - state0 = rnn_cell.LSTMStateTuple(c0, h0) - - xout, sout = cell()(x, state0) - - sess.run([variables.global_variables_initializer()]) - res = sess.run( - [xout, sout], { - x.name: np.array([[1., 1.]]), - c0.name: 0.1 * np.asarray([[0, 1]]), - h0.name: 0.1 * np.asarray([[2, 3]]), - }) - - actual_state_c = res[1].c - actual_state_h = res[1].h - - return actual_state_c, actual_state_h - - def testBasicCell(self): - """Tests cell w/o peepholes and w/o normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=False) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.65937078, 0.74983585]]) - expected_h = np.array([[0.44923624, 0.49362513]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - def testNonbasicCell(self): - """Tests cell with peepholes and w/o normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=True) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.65937084, 0.7574988]]) - expected_h = np.array([[0.4792085, 0.53470564]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - def testBasicCellWithNorm(self): - """Tests cell w/o peepholes and with normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=False) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.50125383, 0.58805949]]) - expected_h = np.array([[0.32770363, 0.37397948]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - def testNonBasicCellWithNorm(self): - """Tests cell with peepholes and with normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=True) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.50125383, 0.59587258]]) - expected_h = np.array([[0.35041603, 0.40873795]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - class WeightNormLSTMCellTest(test.TestCase): """Compared cell output with pre-calculated values.""" diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 5fee2e93e4e575f647b3e6f132a57c57056726ad..6af9db3f15d6040133bfffdc1891a833e625958c 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -2285,7 +2285,7 @@ class GLSTMCell(rnn_cell_impl.RNNCell): else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units - self._linear1 = None + self._linear1 = [None] * number_of_groups self._linear2 = None @property @@ -2359,9 +2359,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell): self._group_shape[0]) ], axis=1) - if self._linear1 is None: - self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False) - R_k = self._linear1(x_g_id) # pylint: disable=invalid-name + linear = self._linear1[group_id] + if linear is None: + linear = _Linear(x_g_id, 4 * self._group_shape[1], False) + self._linear1[group_id] = linear + R_k = linear(x_g_id) # pylint: disable=invalid-name i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) i_parts.append(i_k) diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 95dea312f3a4e77176a4bc4af290ad48c078deda..0a53fd66dbe4d28ea102773b9c5bae50b9d18e9c 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -331,7 +331,7 @@ def _luong_score(query, keys, scale): # batched matmul on: # [batch_size, 1, depth] . [batch_size, depth, max_time] # resulting in an output shape of: - # [batch_time, 1, max_time]. + # [batch_size, 1, max_time]. # we then squeeze out the center singleton dimension. score = math_ops.matmul(query, keys, transpose_b=True) score = array_ops.squeeze(score, [1]) @@ -924,8 +924,7 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, seed=sigmoid_noise_seed) super(LuongMonotonicAttention, self).__init__( - query_layer=layers_core.Dense( - num_units, name="query_layer", use_bias=False, dtype=dtype), + query_layer=None, memory_layer=layers_core.Dense( num_units, name="memory_layer", use_bias=False, dtype=dtype), memory=memory, diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index ef3722ee41bb0b49e5f81d4d6514e2f40d2ad9f1..3245cc5e72154289ea3ba000b9a30586a7ad03a9 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -184,6 +184,7 @@ class TrainingHelper(Helper): """ with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]): inputs = ops.convert_to_tensor(inputs, name="inputs") + self._inputs = inputs if not time_major: inputs = nest.map_structure(_transpose_batch_time, inputs) @@ -200,6 +201,14 @@ class TrainingHelper(Helper): self._batch_size = array_ops.size(sequence_length) + @property + def inputs(self): + return self._inputs + + @property + def sequence_length(self): + return self._sequence_length + @property def batch_size(self): return self._batch_size diff --git a/tensorflow/contrib/session_bundle/bundle_shim.py b/tensorflow/contrib/session_bundle/bundle_shim.py index 3149875e41f6f77b3bcbc0ab1a150cfdc59ad2ba..69db594f8ae52e608b34cff74650889aaf41a21e 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.py +++ b/tensorflow/contrib/session_bundle/bundle_shim.py @@ -82,7 +82,8 @@ def _convert_default_signature_to_signature_def(signatures): """ default_signature = signatures.default_signature signature_def = meta_graph_pb2.SignatureDef() - if default_signature.WhichOneof("type") == legacy_constants.REGRESSION_SIGNATURE: + if (default_signature.WhichOneof("type") == + legacy_constants.REGRESSION_SIGNATURE): regression_signature = default_signature.regression_signature signature_def.method_name = signature_constants.REGRESS_METHOD_NAME _add_input_to_signature_def(regression_signature.input.tensor_name, @@ -91,7 +92,8 @@ def _convert_default_signature_to_signature_def(signatures): _add_output_to_signature_def(regression_signature.output.tensor_name, signature_constants.REGRESS_OUTPUTS, signature_def) - elif default_signature.WhichOneof("type") == legacy_constants.CLASSIFICATION_SIGNATURE: + elif (default_signature.WhichOneof("type") == + legacy_constants.CLASSIFICATION_SIGNATURE): classification_signature = default_signature.classification_signature signature_def.method_name = signature_constants.CLASSIFY_METHOD_NAME _add_input_to_signature_def(classification_signature.input.tensor_name, @@ -132,8 +134,10 @@ def _convert_named_signatures_to_signature_def(signatures): signature_constants.PREDICT_OUTPUTS] # TODO(pdudnik): what if there are other signatures? Mimic cr/140900781 once # it is submitted. - if (input_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE or - output_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE): + if (input_signature.WhichOneof("type") != + legacy_constants.GENERIC_SIGNATURE or + output_signature.WhichOneof("type") != + legacy_constants.GENERIC_SIGNATURE): raise RuntimeError("Named input and output signatures can only be " "up-converted if they are generic signature. " "Input signature type is %s, output signature type is " diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index 72f32a0f5554e4dd3e7cbf498a57ee6bfba57211..9a1dd9303f43591888dc49984d81c4a0c6af9846 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -493,17 +493,15 @@ TEST(BundleShimTest, DefaultAndNamedSignatureWithPredict) { ASSERT_FALSE( actual_signature_def_predict->second.inputs().find("foo-input") == actual_signature_def_predict->second.inputs().end()); - EXPECT_EQ("foo-input", - actual_signature_def_predict->second.inputs() - .find("foo-input") - ->second.name()); + EXPECT_EQ("foo-input", actual_signature_def_predict->second.inputs() + .find("foo-input") + ->second.name()); ASSERT_FALSE( actual_signature_def_predict->second.outputs().find("foo-output") == actual_signature_def_predict->second.outputs().end()); - EXPECT_EQ("foo-output", - actual_signature_def_predict->second.outputs() - .find("foo-output") - ->second.name()); + EXPECT_EQ("foo-output", actual_signature_def_predict->second.outputs() + .find("foo-output") + ->second.name()); EXPECT_EQ(kPredictMethodName, actual_signature_def_predict->second.method_name()); } diff --git a/tensorflow/contrib/session_bundle/signature.cc b/tensorflow/contrib/session_bundle/signature.cc index 7133875ad53e77625bbe799f4f886c074a08f1bd..ed70a5b91b231067e8e69951ef7010406e6b22cf 100644 --- a/tensorflow/contrib/session_bundle/signature.cc +++ b/tensorflow/contrib/session_bundle/signature.cc @@ -38,9 +38,9 @@ namespace { Status BatchSizesMatch(const Tensor& input, const Tensor& output) { // Ensure the number of outputs match the number of inputs. if (input.dim_size(0) != output.dim_size(0)) { - return errors::Internal( - strings::StrCat("Input batch size did not match output batch size: ", - input.dim_size(0), " vs. ", output.dim_size(0))); + return errors::Internal(strings::StrCat( + "Input batch size did not match output batch size: ", input.dim_size(0), + " vs. ", output.dim_size(0))); } return Status::OK(); } @@ -100,8 +100,8 @@ Status GetNamedClassificationSignature( const auto& it = signatures.named_signatures().find(name); if (it == signatures.named_signatures().end()) { return errors::NotFound( - strings::StrCat("Missing signature named \"", name, "\" in: ", - DebugStringIfAvailable(signatures))); + strings::StrCat("Missing signature named \"", name, + "\" in: ", DebugStringIfAvailable(signatures))); } if (!it->second.has_classification_signature()) { return errors::FailedPrecondition( @@ -232,8 +232,8 @@ Status GetNamedSignature(const string& name, const auto& it = signatures.named_signatures().find(name); if (it == signatures.named_signatures().end()) { return errors::NotFound( - strings::StrCat("Missing signature named \"", name, "\" in: ", - DebugStringIfAvailable(signatures))); + strings::StrCat("Missing signature named \"", name, + "\" in: ", DebugStringIfAvailable(signatures))); } *signature = it->second; return Status::OK(); diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py index c04f1cf5bad358a14a1827df05a129339502c86f..e7743bdcba180929007d17bdf3b143c64643aacc 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.signal.python.ops import mfcc_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import spectral_ops_test_util @@ -49,6 +50,14 @@ class MFCCTest(test.TestCase): signal = random_ops.random_normal((2, 3, 5)) mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() + def test_unknown_shape(self): + """A test that the op runs when shape and rank are unknown.""" + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(use_gpu=True): + signal = array_ops.placeholder_with_default( + random_ops.random_normal((2, 3, 5)), tensor_shape.TensorShape(None)) + self.assertIsNone(signal.shape.ndims) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py index 6cef95f742515709f0f41632358c2d8663daed2c..4e842f7f10ae07448cc07e5f636ae80a820e656f 100644 --- a/tensorflow/contrib/signal/python/ops/mfcc_ops.py +++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py @@ -105,4 +105,4 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1] dct2 = spectral_ops.dct(log_mel_spectrograms) - return dct2 * math_ops.rsqrt(num_mel_bins * 2.0) + return dct2 * math_ops.rsqrt(math_ops.to_float(num_mel_bins) * 2.0) diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 54362c87b561595697ee64b9d5e565fdc3f0bbe0..83f33806e055cca764c9f52f318c5b71f70e0551 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -738,6 +738,7 @@ def train(train_op, if summary_writer is not None: train_step_kwargs['summary_writer'] = sv.summary_writer + total_loss = 0 should_retry = True while should_retry: try: diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index ee661dfdc11451bb72bc2741b0b54ebf5c1e6543..a6968d8b2a67809e3e63d099ad9448efd619b4d9 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -202,7 +202,7 @@ def create_file_writer(logdir, if flush_millis is None: flush_millis = constant_op.constant(2 * 60 * 1000) if filename_suffix is None: - filename_suffix = constant_op.constant("") + filename_suffix = constant_op.constant(".v2") return _make_summary_writer( name, gen_summary_ops.create_summary_file_writer, diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index a998ac1e111090a3702c0499a54ef1a5c1b3ac90..4abcc20ed334e706c8ae59e2127dfd6f4e152361 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers - +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib @@ -43,8 +43,8 @@ from tensorflow.python.training import training_util KEYS_NAME = 'keys' LOSS_NAME = 'rf_training_loss' TREE_PATHS_PREDICTION_KEY = 'tree_paths' -VARIANCE_PREDICTION_KEY = 'regression_variance' - +VARIANCE_PREDICTION_KEY = 'prediction_variance' +ALL_SERVING_KEY = 'tensorforest_all' EPSILON = 0.000001 @@ -134,7 +134,8 @@ def get_model_fn(params, trainer_id=0, report_feature_importances=False, local_eval=False, - head_scope=None): + head_scope=None, + include_all_in_serving=False): """Return a model function given a way to construct a graph builder.""" if model_head is None: model_head = get_default_head(params, weights_name) @@ -238,7 +239,13 @@ def get_model_fn(params, model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance - + if include_all_in_serving: + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + if not model_ops.output_alternatives: + model_ops.output_alternatives = {} + model_ops.output_alternatives[ALL_SERVING_KEY] = ( + constants.ProblemType.UNSPECIFIED, model_ops.predictions) return model_ops return _model_fn @@ -293,7 +300,8 @@ class TensorForestEstimator(estimator.Estimator): report_feature_importances=False, local_eval=False, version=None, - head=None): + head=None, + include_all_in_serving=False): """Initializes a TensorForestEstimator instance. Args: @@ -339,6 +347,23 @@ class TensorForestEstimator(estimator.Estimator): version: Unused. head: A heads_lib.Head object that calculates losses and such. If None, one will be automatically created based on params. + include_all_in_serving: if True, allow preparation of the complete + prediction dict including the variance to be exported for serving with + the Servo lib; and it also requires calling export_savedmodel with + default_output_alternative_key=ALL_SERVING_KEY, i.e. + estimator.export_savedmodel(export_dir_base=your_export_dir, + serving_input_fn=your_export_input_fn, + default_output_alternative_key=ALL_SERVING_KEY) + if False, resort to default behavior, i.e. export scores and + probabilities but no variances. In this case + default_output_alternative_key should be None while calling + export_savedmodel(). + Note, that due to backward compatibility we cannot always set + include_all_in_serving to True because in this case calling + export_saved_model() without + default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the + saved_model_export_utils.get_output_alternatives() would raise + ValueError. Returns: A `TensorForestEstimator` instance. @@ -357,7 +382,9 @@ class TensorForestEstimator(estimator.Estimator): num_trainers=num_trainers, trainer_id=trainer_id, report_feature_importances=report_feature_importances, - local_eval=local_eval), + local_eval=local_eval, + include_all_in_serving=include_all_in_serving, + ), model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc index 76cfb4c9ca02269f9fee61c767acc6cb4a0b4ca7..cf0db788a419f64ed891df8aa097fa8826f6de91 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc @@ -99,18 +99,17 @@ class HardRoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); Tensor* output_probability = nullptr; TensorShape output_probability_shape; @@ -125,9 +124,8 @@ class HardRoutingFunction : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, output_probability_shape, &output_probability)); - OP_REQUIRES_OK(context, - context->allocate_output(1, output_path_shape, - &output_path)); + OP_REQUIRES_OK( + context, context->allocate_output(1, output_path_shape, &output_path)); auto out_probability = output_probability->tensor(); auto out_path = output_path->tensor(); @@ -144,12 +142,11 @@ class HardRoutingFunction : public OpKernel { out_probability(i, 0) = 1.0; out_path(i, 0) = 0; for (int j = 0; j < tree_depth_ - 1; j++) { - float left_prob = LeftProbability(point, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features); + float left_prob = + LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features); - int32 left_child = 2*node + 1; + int32 left_child = 2 * node + 1; int32 right_child = left_child + 1; float dot_product = 0.0; diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc index 28f50f1a32eb1827a242d527cd42c58487877959..f64155fa55af22d57c6619d8a39da0455dc0de65 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc @@ -85,12 +85,9 @@ REGISTER_OP("KFeatureGradient") class KFeatureGradient : public OpKernel { public: - explicit KFeatureGradient(OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("layer_num", - &layer_num_)); - OP_REQUIRES_OK(context, context->GetAttr("random_seed", - &random_seed_)); + explicit KFeatureGradient(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_)); + OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); } void Compute(OpKernelContext* context) override { @@ -101,14 +98,14 @@ class KFeatureGradient : public OpKernel { const Tensor& routing_tensor = context->input(3); // Extract dimensions from input tensors. - const int32 num_data = static_cast( - input_data_tensor.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data_tensor.shape().dim_size(1)); - const int32 num_nodes = static_cast( - tree_parameters_tensor.shape().dim_size(0)); - const int32 num_features_per_node = static_cast( - tree_parameters_tensor.shape().dim_size(1)); + const int32 num_data = + static_cast(input_data_tensor.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data_tensor.shape().dim_size(1)); + const int32 num_nodes = + static_cast(tree_parameters_tensor.shape().dim_size(0)); + const int32 num_features_per_node = + static_cast(tree_parameters_tensor.shape().dim_size(1)); // Construct output tensors. Tensor* out_routes = nullptr; @@ -127,12 +124,12 @@ class KFeatureGradient : public OpKernel { out_weights_shape.AddDim(num_nodes); out_weights_shape.AddDim(num_features_per_node); - OP_REQUIRES_OK(context, context->allocate_output( - 0, out_routes_shape, &out_routes)); - OP_REQUIRES_OK(context, context->allocate_output( - 1, out_data_shape, &out_data)); - OP_REQUIRES_OK(context, context->allocate_output( - 2, out_weights_shape, &out_weights)); + OP_REQUIRES_OK(context, + context->allocate_output(0, out_routes_shape, &out_routes)); + OP_REQUIRES_OK(context, + context->allocate_output(1, out_data_shape, &out_data)); + OP_REQUIRES_OK( + context, context->allocate_output(2, out_weights_shape, &out_weights)); tensorforest::Initialize(*out_data, 0.0f); @@ -148,18 +145,13 @@ class KFeatureGradient : public OpKernel { std::vector feature_set; for (int i = 0; i < num_data; i++) { - const Tensor point = input_data_tensor.Slice(i, i+1); + const Tensor point = input_data_tensor.Slice(i, i + 1); feature_set.clear(); // Traverse the tree from the bottom up. for (int j = num_nodes - 1; j >= 0; j--) { - tensorforest::GetFeatureSet( - layer_num_, - j, - random_seed_, - num_features, - num_features_per_node, - &feature_set); + tensorforest::GetFeatureSet(layer_num_, j, random_seed_, num_features, + num_features_per_node, &feature_set); // Compute routing gradient. // j is a leaf node. @@ -170,12 +162,8 @@ class KFeatureGradient : public OpKernel { int32 right_child = left_child + 1; float left_prob = LeftProbabilityK( - point, - feature_set, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features, - num_features_per_node); + point, feature_set, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features, num_features_per_node); float right_prob = 1.0f - left_prob; diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc index 9bc42eb61fae013de3e4ea73aaf371cdaa4ccf9a..e7cafb144da84865ad2b4ea0c33866ddb89119a5 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc @@ -43,7 +43,6 @@ using shape_inference::ShapeHandle; using tensorforest::CheckTensorBounds; using tensorforest::LeftProbabilityK; - // The term 'routing function' is synonymous with 'the probability // that an instance is routed to each leaf node.' It is defined in // 'Deep Neural Decision Forests' by Kontschieder et al. @@ -96,10 +95,8 @@ class KFeatureRoutingFunction : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_)); OP_REQUIRES_OK(context, context->GetAttr("num_features_per_node", &num_features_per_node_)); - OP_REQUIRES_OK(context, context->GetAttr("layer_num", - &layer_num_)); - OP_REQUIRES_OK(context, context->GetAttr("random_seed", - &random_seed_)); + OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_)); + OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); } void Compute(OpKernelContext* context) override { @@ -108,27 +105,25 @@ class KFeatureRoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); Tensor* output_probabilities = nullptr; TensorShape output_shape; output_shape.AddDim(num_data); output_shape.AddDim(max_nodes_); - OP_REQUIRES_OK(context, - context->allocate_output(0, output_shape, - &output_probabilities)); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, + &output_probabilities)); auto out_probs = output_probabilities->tensor(); const auto tree_biases = tree_biases_tensor.tensor(); @@ -136,30 +131,22 @@ class KFeatureRoutingFunction : public OpKernel { // Iteratively compute the probability of reaching each leaf. std::vector feature_set; for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); + const Tensor point = input_data.Slice(i, i + 1); out_probs(i, 0) = 1.0f; for (int j = 0; j < max_nodes_ / 2; j++) { feature_set.clear(); - tensorforest::GetFeatureSet( - layer_num_, - i, - random_seed_, - num_features, - num_features_per_node_, - &feature_set); - - int32 left_child = 2*j + 1; + tensorforest::GetFeatureSet(layer_num_, i, random_seed_, num_features, + num_features_per_node_, &feature_set); + + int32 left_child = 2 * j + 1; int32 right_child = left_child + 1; float prob = out_probs(i, j); - float left_prob = LeftProbabilityK(point, - feature_set, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features, - num_features_per_node_); + float left_prob = LeftProbabilityK( + point, feature_set, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features, num_features_per_node_); out_probs(i, left_child) = prob * left_prob; out_probs(i, right_child) = prob * (1.0f - left_prob); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc index 4027e732b3f52585c2149c3cdc71535664f04ed4..0c2eaabe8f3e1e1377a8d5c5308aaec00030a20f 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc @@ -90,46 +90,43 @@ class RoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); Tensor* output_probabilities = nullptr; TensorShape output_shape; output_shape.AddDim(num_data); output_shape.AddDim(max_nodes_); - OP_REQUIRES_OK(context, - context->allocate_output(0, output_shape, - &output_probabilities)); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, + &output_probabilities)); auto out_probs = output_probabilities->tensor(); const auto tree_biases = tree_biases_tensor.tensor(); // Iteratively compute the probability of reaching each leaf. for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); + const Tensor point = input_data.Slice(i, i + 1); out_probs(i, 0) = 1.0; for (int j = 0; j < max_nodes_ / 2; j++) { - int32 left_child = 2*j + 1; + int32 left_child = 2 * j + 1; int32 right_child = left_child + 1; float prob = out_probs(i, j); - float left_prob = LeftProbability(point, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features); + float left_prob = + LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features); out_probs(i, left_child) = prob * left_prob; out_probs(i, right_child) = prob * (1.0 - left_prob); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc index 66aa293dc1cb93b82f06d838ad7b0f9c09761585..c9df09bfda44e665ed013da383e1e9a2c665c454 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc @@ -96,10 +96,9 @@ class StochasticHardRoutingFunction : public OpKernel { explicit StochasticHardRoutingFunction(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_)); - OP_REQUIRES_OK(context, context->GetAttr("random_seed", - &random_seed_)); + OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); single_rand_ = std::unique_ptr( - new random::PhiloxRandom(random_seed_)); + new random::PhiloxRandom(random_seed_)); rng_ = std::unique_ptr( new random::SimplePhilox(single_rand_.get())); } @@ -111,20 +110,19 @@ class StochasticHardRoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); - const int32 num_nodes = static_cast( - tree_parameters_tensor.shape().dim_size(0)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); + const int32 num_nodes = + static_cast(tree_parameters_tensor.shape().dim_size(0)); Tensor* output_probability = nullptr; TensorShape output_probability_shape; @@ -139,9 +137,8 @@ class StochasticHardRoutingFunction : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, output_probability_shape, &output_probability)); - OP_REQUIRES_OK(context, - context->allocate_output(1, output_path_shape, - &output_path)); + OP_REQUIRES_OK( + context, context->allocate_output(1, output_path_shape, &output_path)); auto out_probability = output_probability->tensor(); auto out_path = output_path->tensor(); @@ -150,19 +147,18 @@ class StochasticHardRoutingFunction : public OpKernel { // Stochastically traverse the tree to a leaf. for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); + const Tensor point = input_data.Slice(i, i + 1); int32 node = 0; out_probability(i, 0) = 1.0; out_path(i, 0) = 0; for (int j = 0; j < tree_depth_ - 1; j++) { - int32 left_child = 2*node + 1; + int32 left_child = 2 * node + 1; int32 right_child = left_child + 1; - float left_prob = LeftProbability(point, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features); + float left_prob = + LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features); if (left_prob < rng_->RandFloat()) { CHECK_LT(i, num_data); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc index 0b5afe464f4b9608af0feca584aaa799f5980f46..b0d8b832b5437db7a4b3026e80ae99d0391d7f7a 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc @@ -149,14 +149,14 @@ class StochasticHardRoutingGradient : public OpKernel { TensorShape output_bias_shape; output_bias_shape.AddDim(num_data); - OP_REQUIRES_OK(context, context->allocate_output( - 0, output_routing_shape, &output_routing)); - OP_REQUIRES_OK(context, context->allocate_output( - 1, output_data_shape, &output_data)); - OP_REQUIRES_OK(context, context->allocate_output( - 2, output_parameters_shape, &output_parameters)); - OP_REQUIRES_OK(context, context->allocate_output( - 3, output_bias_shape, &output_bias)); + OP_REQUIRES_OK(context, context->allocate_output(0, output_routing_shape, + &output_routing)); + OP_REQUIRES_OK( + context, context->allocate_output(1, output_data_shape, &output_data)); + OP_REQUIRES_OK(context, context->allocate_output(2, output_parameters_shape, + &output_parameters)); + OP_REQUIRES_OK( + context, context->allocate_output(3, output_bias_shape, &output_bias)); tensorforest::Initialize(*output_routing, 0.0); tensorforest::Initialize(*output_data, 0.0); @@ -178,7 +178,7 @@ class StochasticHardRoutingGradient : public OpKernel { const Tensor point = input_data.Slice(i, i + 1); // Traverses the tree from the bottom up. - for (int j = tree_depth_-1; j > -1; j--) { + for (int j = tree_depth_ - 1; j > -1; j--) { int32 node = path(i, j); CHECK_LT(node, num_nodes); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc index cacad03e274c3279eb3706e71e1bcdf8433ca1ef..25825a78a1498490009fe4ff6bbfc67493727037 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc @@ -64,8 +64,7 @@ REGISTER_OP("UnpackPath") class UnpackPath : public OpKernel { public: - explicit UnpackPath(OpKernelConstruction* context) - : OpKernel(context) {} + explicit UnpackPath(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { VLOG(1) << "unpack start"; @@ -73,8 +72,8 @@ class UnpackPath : public OpKernel { const Tensor& path_values_tensor = context->input(1); const int32 num_data = static_cast(path_tensor.shape().dim_size(0)); - const int32 tree_depth = static_cast( - path_tensor.shape().dim_size(1)); + const int32 tree_depth = + static_cast(path_tensor.shape().dim_size(1)); const int32 num_nodes = MathUtil::IPow(2, tree_depth) - 1; @@ -107,7 +106,6 @@ class UnpackPath : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU), - UnpackPath); +REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU), UnpackPath); } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc index c091a73c4e48a47bdccea3ec99371faab9c586c2..34388fe1aab72895a805141ec66a71ecf0f42ba4 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc @@ -25,9 +25,7 @@ namespace tensorforest { using tensorflow::Tensor; -float LeftProbability(const Tensor& point, - const Tensor& weight, - float bias, +float LeftProbability(const Tensor& point, const Tensor& weight, float bias, int num_features) { const auto p = point.unaligned_flat(); const auto w = weight.unaligned_flat(); @@ -41,11 +39,8 @@ float LeftProbability(const Tensor& point, return 1.0 / (1.0 + exp(-dot_product + bias)); } -float LeftProbabilityK(const Tensor& point, - std::vector feature_set, - const Tensor& weight, - float bias, - int num_features, +float LeftProbabilityK(const Tensor& point, std::vector feature_set, + const Tensor& weight, float bias, int num_features, int k) { const auto p = point.unaligned_flat(); const auto w = weight.unaligned_flat(); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h index c5902184f95ea8f97be4a10d1101a38333359d44..69a0143a4e319157a4526ca80fbb3f6472902b31 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h @@ -24,16 +24,11 @@ namespace tensorflow { namespace tensorforest { // Returns the probability that the point falls to the left. -float LeftProbability(const Tensor& point, - const Tensor& weight, - float bias, +float LeftProbability(const Tensor& point, const Tensor& weight, float bias, int num_features); -float LeftProbabilityK(const Tensor& point, - std::vector feature_set, - const Tensor& weight, - float bias, - int num_features, +float LeftProbabilityK(const Tensor& point, std::vector feature_set, + const Tensor& weight, float bias, int num_features, int k); // Returns a random set of num_features_to_pick features in the @@ -49,5 +44,3 @@ void GetFeatureSet(int32 tree_num, int32 node_num, int32 random_seed, } // namespace tensorflow #endif // LEARNING_LIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_ - - diff --git a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc index 47b49a379c4b7a17d35b52c1403f67c2f07aeeaf..b21a9179777c21f65435e136aa6082e27fb3b78c 100644 --- a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc +++ b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc @@ -30,15 +30,13 @@ namespace tensorflow { using tensorforest::CheckTensorBounds; - float Convert(const string& in) { const std::size_t intval = std::hash()(in); return static_cast(intval); } - -void Evaluate(const Tensor& input_data, Tensor output_data, - int32 start, int32 end) { +void Evaluate(const Tensor& input_data, Tensor output_data, int32 start, + int32 end) { auto out_data = output_data.unaligned_flat(); const auto in_data = input_data.unaligned_flat(); @@ -59,9 +57,8 @@ class ReinterpretStringToFloat : public OpKernel { if (!CheckTensorBounds(context, input_data)) return; Tensor* output_data = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, input_data.shape(), - &output_data)); + OP_REQUIRES_OK( + context, context->allocate_output(0, input_data.shape(), &output_data)); // Evaluate input data in parallel. const int32 num_data = static_cast(input_data.NumElements()); @@ -73,8 +70,8 @@ class ReinterpretStringToFloat : public OpKernel { auto work = [&input_data, output_data, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - Evaluate(input_data, *output_data, - static_cast(start), static_cast(end)); + Evaluate(input_data, *output_data, static_cast(start), + static_cast(end)); }; Shard(num_threads, worker_threads->workers, num_data, 100, work); } diff --git a/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc b/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc index dd2a98b08cdb486c98c161390a3a1f81d31e1f4b..60740c2be3703141805c7eae0ac384edf934ab3d 100644 --- a/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc +++ b/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc @@ -22,7 +22,6 @@ #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/logging.h" - namespace tensorflow { using tensorforest::CheckTensorBounds; @@ -38,20 +37,19 @@ class ScatterAddNdim : public OpKernel { if (indices_tensor.shape().dim_size(0) > 0) { OP_REQUIRES(context, indices_tensor.shape().dims() == 2, - errors::InvalidArgument( - "indices should be two-dimensional")); + errors::InvalidArgument("indices should be two-dimensional")); const int32 delta_dims = deltas_tensor.shape().dims(); OP_REQUIRES( context, indices_tensor.shape().dim_size(1) + delta_dims == - input_tensor.shape().dims() + 1, + input_tensor.shape().dims() + 1, errors::InvalidArgument( "Number of indices dimensions should be the same as input " "rank.")); OP_REQUIRES( context, indices_tensor.shape().dim_size(0) == - deltas_tensor.shape().dim_size(0), + deltas_tensor.shape().dim_size(0), errors::InvalidArgument( "Number of updates should be same as number of indices.")); } else { @@ -68,8 +66,8 @@ class ScatterAddNdim : public OpKernel { const auto indices = indices_tensor.tensor(); const auto deltas = deltas_tensor.unaligned_flat(); - const int32 num_dims = static_cast( - indices_tensor.shape().dim_size(1)); + const int32 num_dims = + static_cast(indices_tensor.shape().dim_size(1)); // Figure out if indices don't specify a complete position in the // input tensor. @@ -80,10 +78,9 @@ class ScatterAddNdim : public OpKernel { // Calculate index multipliers. std::vector multipliers; - OP_REQUIRES( - context, input.size() < std::numeric_limits::max(), - errors::InvalidArgument( - "Input must contain less than 2^31 total elements")); + OP_REQUIRES(context, input.size() < std::numeric_limits::max(), + errors::InvalidArgument( + "Input must contain less than 2^31 total elements")); int32 last_size = static_cast(input.size()); for (int32 j = 0; j < num_dims; j++) { diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc index 94e12cea5a072f0746e642196d55f3a3b13a16c3..44997ec5d6d5fdb9aab52ab7a50f46a731bfda66 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc @@ -65,8 +65,8 @@ void GetTwoBest(int max, const std::function& score_fn, float ClassificationSplitScore( const Eigen::Tensor& splits, - const Eigen::Tensor& rights, - int32 num_classes, int i) { + const Eigen::Tensor& rights, int32 num_classes, + int i) { Eigen::array offsets; // Class counts are stored with the total in [0], so the length of each // count vector is num_classes + 1. @@ -74,7 +74,7 @@ float ClassificationSplitScore( Eigen::array extents; extents[0] = num_classes; return WeightedGiniImpurity(splits.slice(offsets, extents)) + - WeightedGiniImpurity(rights.slice(offsets, extents)); + WeightedGiniImpurity(rights.slice(offsets, extents)); } void GetTwoBestClassification(const Tensor& total_counts, @@ -90,29 +90,28 @@ void GetTwoBestClassification(const Tensor& total_counts, // in seg faults, so we have to go with flat views of these tensors. However, // it is still pretty efficient because we put off evaluation until the // score is actually returned. - const auto tc = total_counts.Slice( - accumulator, accumulator + 1).unaligned_flat(); + const auto tc = + total_counts.Slice(accumulator, accumulator + 1).unaligned_flat(); // TODO(gilberth): See if we can delay evaluation here by templating the // arguments to ClassificationSplitScore. - const Eigen::Tensor splits = split_counts.Slice( - accumulator, accumulator + 1).unaligned_flat(); + const Eigen::Tensor splits = + split_counts.Slice(accumulator, accumulator + 1).unaligned_flat(); Eigen::array bcast; bcast[0] = num_splits; const Eigen::Tensor rights = tc.broadcast(bcast) - splits; - std::function score_fn = std::bind( - ClassificationSplitScore, splits, rights, num_classes, - std::placeholders::_1); + std::function score_fn = + std::bind(ClassificationSplitScore, splits, rights, num_classes, + std::placeholders::_1); GetTwoBest(num_splits, score_fn, best_score, best_index, second_best_score, second_best_index); } -int32 BestFeatureClassification( - const Tensor& total_counts, const Tensor& split_counts, - int32 accumulator) { +int32 BestFeatureClassification(const Tensor& total_counts, + const Tensor& split_counts, int32 accumulator) { float best_score; float second_best_score; int best_feature_index; @@ -130,8 +129,7 @@ float RegressionSplitScore( const Eigen::Tensor& splits_square, const Eigen::Tensor& right_sums, const Eigen::Tensor& right_squares, - int32 accumulator, - int32 num_regression_dims, int i) { + int32 accumulator, int32 num_regression_dims, int i) { Eigen::array offsets = {i * num_regression_dims + 1}; Eigen::array extents = {num_regression_dims - 1}; float left_count = splits_count_accessor(accumulator, i, 0); @@ -141,15 +139,15 @@ float RegressionSplitScore( // Guard against divide-by-zero. if (left_count > 0) { - score += WeightedVariance( - splits_sum.slice(offsets, extents), - splits_square.slice(offsets, extents), left_count); + score += + WeightedVariance(splits_sum.slice(offsets, extents), + splits_square.slice(offsets, extents), left_count); } if (right_count > 0) { - score += WeightedVariance(right_sums.slice(offsets, extents), - right_squares.slice(offsets, extents), - right_count); + score += + WeightedVariance(right_sums.slice(offsets, extents), + right_squares.slice(offsets, extents), right_count); } return score; } @@ -159,20 +157,20 @@ void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares, int32 accumulator, float* best_score, int* best_index, float* second_best_score, int* second_best_index) { const int32 num_splits = static_cast(split_sums.shape().dim_size(1)); - const int32 num_regression_dims = static_cast( - split_sums.shape().dim_size(2)); + const int32 num_regression_dims = + static_cast(split_sums.shape().dim_size(2)); // Ideally, Eigen::Tensor::chip would be best to use here but it results // in seg faults, so we have to go with flat views of these tensors. However, // it is still pretty efficient because we put off evaluation until the // score is actually returned. - const auto tc_sum = total_sums.Slice( - accumulator, accumulator + 1).unaligned_flat(); - const auto tc_square = total_squares.Slice( - accumulator, accumulator + 1).unaligned_flat(); - const auto splits_sum = split_sums.Slice( - accumulator, accumulator + 1).unaligned_flat(); - const auto splits_square = split_squares.Slice( - accumulator, accumulator + 1).unaligned_flat(); + const auto tc_sum = + total_sums.Slice(accumulator, accumulator + 1).unaligned_flat(); + const auto tc_square = + total_squares.Slice(accumulator, accumulator + 1).unaligned_flat(); + const auto splits_sum = + split_sums.Slice(accumulator, accumulator + 1).unaligned_flat(); + const auto splits_square = + split_squares.Slice(accumulator, accumulator + 1).unaligned_flat(); // Eigen is infuriating to work with, usually resulting in all kinds of // unhelpful compiler errors when trying something that seems sane. This // helps us do a simple thing like access the first element (the counts) @@ -193,10 +191,10 @@ void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares, best_score, best_index, second_best_score, second_best_index); } -int32 BestFeatureRegression( - const Tensor& total_sums, const Tensor& total_squares, - const Tensor& split_sums, const Tensor& split_squares, - int32 accumulator) { +int32 BestFeatureRegression(const Tensor& total_sums, + const Tensor& total_squares, + const Tensor& split_sums, + const Tensor& split_squares, int32 accumulator) { float best_score; float second_best_score; int best_feature_index; @@ -207,10 +205,11 @@ int32 BestFeatureRegression( return best_feature_index; } -bool BestSplitDominatesRegression( - const Tensor& total_sums, const Tensor& total_squares, - const Tensor& split_sums, const Tensor& split_squares, - int32 accumulator) { +bool BestSplitDominatesRegression(const Tensor& total_sums, + const Tensor& total_squares, + const Tensor& split_sums, + const Tensor& split_squares, + int32 accumulator) { // TODO(thomaswc): Implement this, probably as part of v3. return false; } @@ -599,7 +598,6 @@ bool Decide(float value, float bias, DataColumnTypes type) { } } - void GetParentWeightedMean(float leaf_sum, const float* leaf_data, float parent_sum, const float* parent_data, float valid_leaf_threshold, int num_outputs, diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h index dad9df4898844eaa17bdfe5b4b298a95377fd12e..edbac6700677633cbd4d41f7040b4859ca599c4a 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h @@ -45,13 +45,10 @@ const int32 LEAF_NODE = -1; const int32 FREE_NODE = -2; // Used to indicate column types, e.g. categorical vs. float -enum DataColumnTypes { - kDataFloat = 0, - kDataCategorical = 1 -}; +enum DataColumnTypes { kDataFloat = 0, kDataCategorical = 1 }; // Calculates the sum of a tensor. -template +template T Sum(Tensor counts) { Eigen::Tensor count_sum = counts.unaligned_flat().sum(); @@ -97,7 +94,7 @@ float WeightedGiniImpurity(const T& counts) { return RawWeightedGiniImpurity(smoothed); } -template +template float WeightedVariance(const T1& sums, const T2& squares, float count) { const auto e_x = sums / count; const auto e_x2 = squares / count; @@ -120,10 +117,11 @@ int32 BestFeatureRegression(const Tensor& total_sums, // Returns true if the best split's variance is sufficiently smaller than // that of the next best split. -bool BestSplitDominatesRegression( - const Tensor& total_sums, const Tensor& total_squares, - const Tensor& split_sums, const Tensor& split_squares, - int32 accumulator); +bool BestSplitDominatesRegression(const Tensor& total_sums, + const Tensor& total_squares, + const Tensor& split_sums, + const Tensor& split_squares, + int32 accumulator); // Performs booststrap_samples bootstrap samples of the best split's class // counts and the second best splits's class counts, and returns true if at @@ -178,10 +176,8 @@ bool DecideNode(const GetFeatureFnType& get_dense, // isn't present in sparse_input_indices. sparse_input_indices is assumed // to be sorted. template -float FindSparseValue( - const T1& sparse_input_indices, - const T2& sparse_input_values, - int32 i, int32 j) { +float FindSparseValue(const T1& sparse_input_indices, + const T2& sparse_input_values, int32 i, int32 j) { int32 low = 0; int32 high = sparse_input_values.dimension(0); while (low < high) { @@ -273,7 +269,6 @@ int32 GetNumSparseFeatures(const T1& indices, int32 input_index, // categorical data, it is value != bias. bool Decide(float value, float bias, DataColumnTypes type = kDataFloat); - // Returns true if all the splits are initialized. Since they get initialized // in order, we can simply infer this from the last split. // This should only be called for a single allocator's candidate features diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc index 7485a695dfba93fd3f57c19096b205b10e2fa8b5..08553545502c21eb8f2d68bfd342f8ba7c081adb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc @@ -44,11 +44,13 @@ TEST(TestWeightedVariance, Basic) { Tensor squares = test::AsTensor({29, 12}, {2}); EXPECT_FLOAT_EQ(WeightedVariance(sums.unaligned_flat(), - squares.unaligned_flat(), 3), 2.0); + squares.unaligned_flat(), 3), + 2.0); Tensor zero = test::AsTensor({0}, {1}); EXPECT_FLOAT_EQ(WeightedVariance(zero.unaligned_flat(), - zero.unaligned_flat(), 1), 0); + zero.unaligned_flat(), 1), + 0); } TEST(TestInitialize, Basic) { @@ -94,17 +96,16 @@ TEST(BestFeatureClassification, Basic) { const int32 num_accumulators = 4; const int32 num_splits = 3; const int32 num_classes = 4; - Tensor totals = test::AsTensor({1, 5, 6, 7, - 0, 0, 0, 0, - 30, 10, 10, 10, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); - Tensor splits = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 30, 10, 10, 10, 10, 0, 0, 10, 19, 5, 6, 8, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor totals = test::AsTensor( + {1, 5, 6, 7, 0, 0, 0, 0, 30, 10, 10, 10, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); + Tensor splits = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 10, + 10, 10, 10, 0, 0, 10, 19, 5, 6, 8, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureClassification(totals, splits, 2), 1); } @@ -114,17 +115,16 @@ TEST(BestFeatureClassification, NoWinner) { const int32 num_splits = 3; const int32 num_classes = 4; // When counts are all the same, the most reasonable thing to do is pick 0. - Tensor totals = test::AsTensor({1, 5, 6, 7, - 0, 0, 0, 0, - 18, 6, 6, 6, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); - Tensor splits = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 9, 3, 3, 3, 9, 3, 3, 3, 9, 3, 3, 3, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor totals = + test::AsTensor({1, 5, 6, 7, 0, 0, 0, 0, 18, 6, 6, 6, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); + Tensor splits = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 3, + 3, 3, 9, 3, 3, 3, 9, 3, 3, 3, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureClassification(totals, splits, 2), 0); } @@ -133,36 +133,34 @@ TEST(BestFeatureRegression, Basic) { const int32 num_accumulators = 4; const int32 num_splits = 3; const int32 num_classes = 4; - Tensor total_sums = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 10, 8, 6, 9, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); + Tensor total_sums = + test::AsTensor({1, 5, 6, 7, 0, 0, 0, 0, 10, 8, 6, 9, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); Tensor total_squares = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 100, 50, 40, 45, // this one + {1, 5, 6, 7, 0, 0, 0, 0, 100, 50, 40, 45, // this one -1, -1, -1, -1}, {num_accumulators, num_classes}); - Tensor split_sums = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 10, 8, 6, 9, 9, 8, 5, 9, 0, 0, 0, 0, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor split_sums = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 8, + 6, 9, 9, 8, 5, 9, 0, 0, 0, 0, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); // lower the variance by lowering one of the squares just a little. - Tensor split_squares = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 100, 50, 40, 45, 100, 50, 40, 43, 0, 0, 0, 0, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor split_squares = + test::AsTensor( + {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 100, 50, 40, 45, 100, 50, 40, 43, 0, 0, 0, 0, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureRegression(total_sums, total_squares, split_sums, - split_squares, 2), 1); + split_squares, 2), + 1); } TEST(BestFeatureRegression, NoWinner) { @@ -170,37 +168,33 @@ TEST(BestFeatureRegression, NoWinner) { const int32 num_splits = 3; const int32 num_classes = 4; // when counts are all the same, the most reasonable thing to do is pick 0. - Tensor total_sums = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 10, 8, 6, 9, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); + Tensor total_sums = + test::AsTensor({1, 5, 6, 7, 0, 0, 0, 0, 10, 8, 6, 9, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); Tensor total_squares = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 100, 50, 40, 45, // this one + {1, 5, 6, 7, 0, 0, 0, 0, 100, 50, 40, 45, // this one -1, -1, -1, -1}, {num_accumulators, num_classes}); - Tensor split_sums = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 10, 8, 6, 9, 10, 8, 6, 9, 10, 8, 6, 9, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor split_sums = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 8, + 6, 9, 10, 8, 6, 9, 10, 8, 6, 9, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); Tensor split_squares = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 100, 50, 40, 45, 100, 50, 40, 45, 100, 50, 40, 45, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 100, 50, 40, 45, 100, 50, 40, 45, 100, 50, 40, 45, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureRegression(total_sums, total_squares, split_sums, - split_squares, 2), 0); + split_squares, 2), + 0); } } // namespace tensorforest } // namespace tensorflow - diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc index 81e2a1b2a1b720574210e376fa786923367794a6..f4a7058ddb8bfdd6393a9369006aabc29d058d3b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc @@ -14,8 +14,8 @@ // ============================================================================= #include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { @@ -58,8 +58,7 @@ CandidateGraphRunner::CandidateGraphRunner( // Features don't change, store them in a tensor. const auto& oblique = split.inequality_left_child_test().oblique(); const int32 feat_size = oblique.features_size(); - features_.reset( - new Tensor(tensorflow::DT_INT32, TensorShape({feat_size}))); + features_.reset(new Tensor(tensorflow::DT_INT32, TensorShape({feat_size}))); auto feat = features_->flat(); int i = 0; for (const auto& id : oblique.features()) { @@ -67,10 +66,10 @@ CandidateGraphRunner::CandidateGraphRunner( } } -void CandidateGraphRunner::RunOp( - const string& name, const TensorNameValueList& inputs, - const std::vector& output_tensor_names, - std::vector* outputs) { +void CandidateGraphRunner::RunOp(const string& name, + const TensorNameValueList& inputs, + const std::vector& output_tensor_names, + std::vector* outputs) { std::vector op_name; if (name != kNoOp) { op_name.push_back(name); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index cced26b9036ba8ba6c5994b7483261a062f80588..328af28725af016e90b30ae2d303ffba15c81c1f 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -26,7 +26,6 @@ namespace tensorflow { namespace tensorforest { - // Keep a tree ensemble in memory for efficient evaluation and mutation. class DecisionTreeResource : public ResourceBase { public: @@ -35,15 +34,12 @@ class DecisionTreeResource : public ResourceBase { string DebugString() override { return strings::StrCat("DecisionTree[size=", - decision_tree_->decision_tree().nodes_size(), - "]"); + decision_tree_->decision_tree().nodes_size(), "]"); } void MaybeInitialize(); - const decision_trees::Model& decision_tree() const { - return *decision_tree_; - } + const decision_trees::Model& decision_tree() const { return *decision_tree_; } decision_trees::Model* mutable_decision_tree() { return decision_tree_.get(); @@ -59,9 +55,7 @@ class DecisionTreeResource : public ResourceBase { // Resets the resource and frees the proto. // Caller needs to hold the mutex lock while calling this. - void Reset() { - decision_tree_.reset(new decision_trees::Model()); - } + void Reset() { decision_tree_.reset(new decision_trees::Model()); } mutex* get_mutex() { return &mu_; } @@ -84,7 +78,6 @@ class DecisionTreeResource : public ResourceBase { std::vector> node_evaluators_; }; - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h index 85ce7b825b11983307370bb3ac30eeec9b6b2c99..bf2b2aaa3c8f433ab4fc145217857112f7a0a579 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h @@ -22,7 +22,6 @@ namespace tensorflow { namespace tensorforest { - // Base class for evaluators of decision nodes that effectively copy proto // contents into C++ structures for faster execution. class DecisionNodeEvaluator { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc index 5c49b87443e7b1f4ef532256ae2efdc9fa985d8a..af5cf72a3c0bea0eef45c3446acf52ff389c6751 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc @@ -20,11 +20,11 @@ namespace tensorflow { namespace { +using tensorflow::decision_trees::InequalityTest; +using tensorflow::decision_trees::MatchingValuesTest; using tensorflow::tensorforest::InequalityDecisionNodeEvaluator; using tensorflow::tensorforest::MatchingValuesDecisionNodeEvaluator; using tensorflow::tensorforest::ObliqueInequalityDecisionNodeEvaluator; -using tensorflow::decision_trees::InequalityTest; -using tensorflow::decision_trees::MatchingValuesTest; TEST(InequalityDecisionNodeEvaluatorTest, TestLessOrEqual) { InequalityTest test; @@ -124,4 +124,3 @@ TEST(ObliqueDecisionNodeEvaluatorTest, Basic) { } // namespace } // namespace tensorflow - diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index 0d6712e9e552d7045eb198f7e65d04eb42eff920..eea0be27caf0a022ba7acaacd359c75a2df4eedb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -40,9 +40,7 @@ class FertileStatsResource : public ResourceBase { model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); } - string DebugString() override { - return "FertileStats"; - } + string DebugString() override { return "FertileStats"; } void ExtractFromProto(const FertileStats& stats); @@ -50,8 +48,7 @@ class FertileStatsResource : public ResourceBase { // Resets the resource and frees the proto. // Caller needs to hold the mutex lock while calling this. - void Reset() { - } + void Reset() {} // Reset the stats for a node, but leave the leaf_stats intact. void ResetSplitStats(int32 node_id, int32 depth) { @@ -84,7 +81,6 @@ class FertileStatsResource : public ResourceBase { // was found. bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth); - private: mutex mu_; std::shared_ptr model_op_; @@ -94,7 +90,6 @@ class FertileStatsResource : public ResourceBase { void AllocateNode(int32 node_id, int32 depth); }; - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc index 3ce630e3a9691b87ad291a9f29616f741953dd84..da600d34eacdf27514709240723e5bb730cfe7f0 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc @@ -20,7 +20,6 @@ #include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" #include "tensorflow/core/lib/random/distribution_sampler.h" - namespace tensorflow { namespace tensorforest { @@ -454,14 +453,14 @@ void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const { class_stats->add_value()->set_float_value(total_counts_[i]); } - for (int split_num = 0; split_num < num_splits(); ++split_num) { + for (int split_num = 0; split_num < num_splits(); ++split_num) { auto* cand = slot->add_candidates(); *cand->mutable_split() = splits_[split_num]; auto* left_stats = cand->mutable_left_stats() ->mutable_classification() ->mutable_dense_counts(); for (int i = 0; i < num_outputs_; ++i) { - left_stats->add_value()->set_float_value(left_count(split_num, i)); + left_stats->add_value()->set_float_value(left_count(split_num, i)); } } } @@ -546,7 +545,7 @@ void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const { (*class_stats)[entry.first] = val; } - for (int split_num = 0; split_num < num_splits(); ++split_num) { + for (int split_num = 0; split_num < num_splits(); ++split_num) { auto* cand = slot->add_candidates(); *cand->mutable_split() = splits_[split_num]; auto* left_stats = cand->mutable_left_stats() @@ -561,8 +560,8 @@ void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const { } } -float SparseClassificationGrowStats::GiniScore( - int split, float* left_sum, float* right_sum) const { +float SparseClassificationGrowStats::GiniScore(int split, float* left_sum, + float* right_sum) const { float left_square = 0, right_square = 0; *left_sum = 0; *right_sum = 0; @@ -844,12 +843,11 @@ void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const { total_squares->add_value()->set_float_value(total_sum_squares_[i]); } - for (int split_num = 0; split_num < num_splits(); ++split_num) { + for (int split_num = 0; split_num < num_splits(); ++split_num) { auto* cand = slot->add_candidates(); *cand->mutable_split() = splits_[split_num]; - auto* sums = cand->mutable_left_stats() - ->mutable_regression() - ->mutable_mean_output(); + auto* sums = + cand->mutable_left_stats()->mutable_regression()->mutable_mean_output(); auto* squares = cand->mutable_left_stats() ->mutable_regression() ->mutable_mean_output_squares(); @@ -891,20 +889,17 @@ float LeastSquaresRegressionGrowStats::SplitVariance(int split) const { float total_variance = 0; for (int i = 0; i < params_.num_outputs(); ++i) { // Left side - const float le_x = - left_sum(split, i) / left_counts_[split]; + const float le_x = left_sum(split, i) / left_counts_[split]; - const float le_x2 = - left_square(split, i) / left_counts_[split]; + const float le_x2 = left_square(split, i) / left_counts_[split]; total_variance += le_x2 - le_x * le_x; // Right side const float re_x = (total_sum_[i] - left_sum(split, i)) / (weight_sum_ - left_counts_[split]); - const float re_x2 = - (total_sum_squares_[i] - left_square(split, i)) / - (weight_sum_ - left_counts_[split]); + const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) / + (weight_sum_ - left_counts_[split]); total_variance += re_x2 - re_x * re_x; } return total_variance; @@ -937,8 +932,7 @@ bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const { left->set_weight_sum(left_counts_[best_index]); auto* left_output_sum = left_reg_stats->mutable_mean_output(); for (int i = 0; i < num_outputs; ++i) { - left_output_sum->add_value()->set_float_value( - left_sum(best_index, i)); + left_output_sum->add_value()->set_float_value(left_sum(best_index, i)); } // Right @@ -947,8 +941,8 @@ bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const { right->set_weight_sum(weight_sum_ - left_counts_[best_index]); auto* right_output_sum = right_reg_stats->mutable_mean_output(); for (int i = 0; i < num_outputs; ++i) { - right_output_sum->add_value()->set_float_value( - total_sum_[i] - left_sum(best_index, i)); + right_output_sum->add_value()->set_float_value(total_sum_[i] - + left_sum(best_index, i)); } return true; } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h index 02c0fc687fffc022f9f41ffce7acfcddba5d4b45..04e6b0a735320dd024e326a94ef910593a326245 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -73,21 +73,15 @@ class GrowStats { const InputTarget* target, int example) {} void RemoveSplit(int split_num); - int num_splits() const { - return splits_.size(); - } + int num_splits() const { return splits_.size(); } - float weight_sum() const { - return weight_sum_; - } + float weight_sum() const { return weight_sum_; } virtual bool IsInitialized() const { return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_; } - int32 depth() const { - return depth_; - } + int32 depth() const { return depth_; } protected: GrowStats(const TensorForestParams& params, int32 depth); @@ -206,8 +200,8 @@ class ClassificationStats : public GrowStats { virtual float left_count(int split, int class_num) const = 0; virtual float right_count(int split, int class_num) const = 0; - virtual void ClassificationAddLeftExample( - int split, int64 int_label, float weight) = 0; + virtual void ClassificationAddLeftExample(int split, int64 int_label, + float weight) = 0; virtual void ClassificationAddRightExample(int split, int64 int_label, float weight) { // Does nothing by default, but sub-classes can override. @@ -375,9 +369,7 @@ class SparseClassificationGrowStats : public ClassificationStats { SparseClassificationGrowStats(const TensorForestParams& params, int32 depth) : ClassificationStats(params, depth) {} - void Initialize() override { - Clear(); - } + void Initialize() override { Clear(); } void ExtractFromProto(const FertileSlot& slot) override; void PackToProto(FertileSlot* slot) const override; @@ -562,9 +554,9 @@ class LeastSquaresRegressionGrowStats : public GrowStats { } void RemoveSplitStats(int split_num) override { left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num, - left_sums_.begin() + num_outputs_ * (split_num + 1)); + left_sums_.begin() + num_outputs_ * (split_num + 1)); left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num, - left_squares_.begin() + num_outputs_ * (split_num + 1)); + left_squares_.begin() + num_outputs_ * (split_num + 1)); left_counts_.erase(left_counts_.begin() + split_num, left_counts_.begin() + (split_num + 1)); } @@ -605,7 +597,6 @@ class LeastSquaresRegressionGrowStats : public GrowStats { std::vector left_counts_; }; - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc index ceb58d2ead5c2f148c96d9cb9532a73688593d33..26e989928e00de1b2ae1646abf216adfbec2be4f 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc @@ -24,21 +24,21 @@ namespace tensorflow { namespace { -using tensorflow::tensorforest::GrowStats; -using tensorflow::tensorforest::TestableInputTarget; -using tensorflow::tensorforest::FertileSlot; +using tensorflow::decision_trees::BinaryNode; +using tensorflow::decision_trees::FeatureId; +using tensorflow::decision_trees::InequalityTest; using tensorflow::tensorforest::DenseClassificationGrowStats; -using tensorflow::tensorforest::SparseClassificationGrowStats; +using tensorflow::tensorforest::FertileSlot; using tensorflow::tensorforest::FixedSizeClassStats; using tensorflow::tensorforest::FixedSizeSparseClassificationGrowStats; +using tensorflow::tensorforest::GrowStats; using tensorflow::tensorforest::LeastSquaresRegressionGrowStats; -using tensorflow::tensorforest::TensorForestParams; +using tensorflow::tensorforest::SparseClassificationGrowStats; using tensorflow::tensorforest::SPLIT_FINISH_BASIC; using tensorflow::tensorforest::SPLIT_FINISH_DOMINATE_HOEFFDING; using tensorflow::tensorforest::SPLIT_PRUNE_HOEFFDING; -using tensorflow::decision_trees::BinaryNode; -using tensorflow::decision_trees::InequalityTest; -using tensorflow::decision_trees::FeatureId; +using tensorflow::tensorforest::TensorForestParams; +using tensorflow::tensorforest::TestableInputTarget; BinaryNode MakeSplit(const string& feat, float val) { BinaryNode split; @@ -52,8 +52,7 @@ BinaryNode MakeSplit(const string& feat, float val) { return split; } -void RunBatch(GrowStats* stats, - const TestableInputTarget* target) { +void RunBatch(GrowStats* stats, const TestableInputTarget* target) { std::unique_ptr dataset( new tensorflow::tensorforest::TestableDataSet( {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2)); @@ -102,18 +101,10 @@ class TestableRunningStats : public DenseClassificationGrowStats { TestableRunningStats(const TensorForestParams& params, int32 depth) : DenseClassificationGrowStats(params, depth) {} - float test_left_sum(int split) { - return get_left_gini()->sum(split); - } - float test_left_square(int split) { - return get_left_gini()->square(split); - } - float test_right_sum(int split) { - return get_right_gini()->sum(split); - } - float test_right_square(int split) { - return get_right_gini()->square(split); - } + float test_left_sum(int split) { return get_left_gini()->sum(split); } + float test_left_square(int split) { return get_left_gini()->square(split); } + float test_right_sum(int split) { return get_right_gini()->sum(split); } + float test_right_square(int split) { return get_right_gini()->square(split); } }; TEST(GrowStatsDenseClassificationTest, BasicRunningStats) { @@ -166,9 +157,7 @@ class TestableFinishEarly : public DenseClassificationGrowStats { int num_times_called_; protected: - void CheckFinishEarlyHoeffding() override { - ++num_times_called_; - } + void CheckFinishEarlyHoeffding() override { ++num_times_called_; } }; TEST(GrowStatsDenseClassificationTest, TestFinishEarly) { @@ -212,7 +201,6 @@ TEST(GrowStatsDenseClassificationTest, TestFinishEarly) { ASSERT_EQ(stat->num_times_called_, 9); } - TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { TensorForestParams params; params.set_num_outputs(2); @@ -224,7 +212,8 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { finish->set_type(SPLIT_FINISH_BASIC); finish->mutable_check_every_steps()->set_constant_value(100); params.mutable_pruning_type()->set_type(SPLIT_PRUNE_HOEFFDING); - params.mutable_pruning_type()->mutable_prune_every_samples() + params.mutable_pruning_type() + ->mutable_prune_every_samples() ->set_constant_value(1); // On each iteration, we add two examples, one of class 0 and one @@ -234,8 +223,8 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { std::vector weights = {1, 1}; TestableInputTarget target(labels, weights, 1); std::unique_ptr dataset( - new tensorflow::tensorforest::TestableDataSet( - {-1.0, -1.0, 1.0, -1.0}, 2)); + new tensorflow::tensorforest::TestableDataSet({-1.0, -1.0, 1.0, -1.0}, + 2)); DenseClassificationGrowStats stats(params, 1); stats.Initialize(); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc index bf0fb9245043c3bbf22e8aafc97b6d0186c3a29f..d43884481afbbbc988d6eb80e01e49663df6914b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc @@ -109,10 +109,10 @@ void TensorDataSet::set_input_tensors(const Tensor& dense, dense_data_.reset(new DenseStorageType(dense.tensor())); } if (sparse_indices.shape().dims() == 2) { - sparse_indices_.reset(new SparseIndicesStorageType( - sparse_indices.tensor())); - sparse_values_.reset(new SparseValuesStorageType( - sparse_values.tensor())); + sparse_indices_.reset( + new SparseIndicesStorageType(sparse_indices.tensor())); + sparse_values_.reset( + new SparseValuesStorageType(sparse_values.tensor())); sparse_batch_size_ = sparse_shape.tensor()(0); } original_dense_tensor_ = dense; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h index eafad6b591672f67ae816405ff603f9aaba30a1b..c544a8c75e9bfe8fe6bbea8913e7be17d868bfef 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -93,9 +93,7 @@ class TensorDataSet { // an int32 you can avoid the atoi32. virtual float GetExampleValue(int example, int32 feature_id) const; - int num_features() { - return available_features_.size(); - } + int num_features() { return available_features_.size(); } const Tensor& original_tensor() const { return original_dense_tensor_; } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h index 44ec09c50ef3d092bd1bf7f051f492e1fffdd05b..d4402b6055a36d38042a0e6cfa07b532ec11c093 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h @@ -79,9 +79,7 @@ class TensorInputTarget : public StoredInputTarget { return (*target_)(example_index * num_targets_ + target_index); } - const Tensor& original_tensor() const { - return original_tensor_; - } + const Tensor& original_tensor() const { return original_tensor_; } protected: Tensor original_tensor_; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc index d43c068e462ff78b114fb29bd8cf0ee0c6080fcd..83614a25314117ef9ba29b4dcf6ebee8f7f3e226 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc @@ -160,6 +160,5 @@ void RegressionLeafModelOperator::ExportModel( } } - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc index ffd92c01f9a59719e6bb2458c2f28253c364a2e8..ab4191809b6a7400114acf85991c74acfac55505 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc @@ -26,19 +26,19 @@ namespace { using tensorflow::decision_trees::Leaf; using tensorflow::tensorforest::DenseClassificationLeafModelOperator; using tensorflow::tensorforest::LeafModelOperator; -using tensorflow::tensorforest::SparseClassificationLeafModelOperator; -using tensorflow::tensorforest::SparseOrDenseClassificationLeafModelOperator; using tensorflow::tensorforest::LeafStat; using tensorflow::tensorforest::RegressionLeafModelOperator; -using tensorflow::tensorforest::TestableInputTarget; +using tensorflow::tensorforest::SparseClassificationLeafModelOperator; +using tensorflow::tensorforest::SparseOrDenseClassificationLeafModelOperator; using tensorflow::tensorforest::TensorForestParams; +using tensorflow::tensorforest::TestableInputTarget; const int32 kNumClasses = 3; constexpr char kRegressionStatProto[] = - "weight_sum: 3 " - "regression { " - "mean_output { " + "weight_sum: 3 " + "regression { " + "mean_output { " "value { " " float_value: 27 " "} " @@ -48,8 +48,8 @@ constexpr char kRegressionStatProto[] = "value { " " float_value: 10 " "} " - "} " - "mean_output_squares { " + "} " + "mean_output_squares { " "value {" " float_value: 245" "}" @@ -59,8 +59,8 @@ constexpr char kRegressionStatProto[] = "value {" " float_value: 46" "}" - "}" -"}"; + "}" + "}"; void TestClassificationNormalUse(const std::unique_ptr& op) { Leaf l; @@ -83,7 +83,6 @@ void TestClassificationNormalUse(const std::unique_ptr& op) { EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4); } - TEST(DenseLeafModelOperatorsTest, NormalUse) { TensorForestParams params; params.set_num_outputs(kNumClasses); @@ -182,7 +181,7 @@ TEST(SparseLeafModelOperatorsTest, InitWithExisting) { std::unique_ptr leaf(new Leaf); - op->ExportModel( *stat, leaf.get()); + op->ExportModel(*stat, leaf.get()); // Make sure it was initialized correctly. EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 0), 1.1); @@ -194,7 +193,6 @@ TEST(SparseLeafModelOperatorsTest, InitWithExisting) { EXPECT_EQ(leaf->sparse_vector().sparse_value().size(), kNumClasses); } - TEST(RegressionLeafModelOperatorsTest, NormalUse) { TensorForestParams params; params.set_num_outputs(kNumClasses); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.h b/tensorflow/contrib/tensor_forest/kernels/v4/params.h index b0ed949424756cc498d4b7ad1fb1867fff11b265..7583e3d0402a3a1d07f3696727b285747dc887de 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/params.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.h @@ -24,7 +24,6 @@ namespace tensorforest { // Return the value of the given depth-dependent parameter given a leaf's depth. float ResolveParam(const DepthDependentParam& param, int32 depth); - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc index 801881af1368dc33f00b356d12bea07ae3161ef6..4010a71006d58df0bec6d3686a9c47433b46fdd4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc @@ -71,5 +71,3 @@ TEST(ParamsTest, TestThreshold) { } } // namespace - - diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index cdb1d80a4bbd47d1481ecde2348bef500bd125f1..b7b60d0ab8c2670cec8b029d1f42c5edd3690afe 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -52,8 +52,8 @@ std::unique_ptr SplitCollectionOperator::CreateGrowStats( new SparseClassificationGrowStats(params_, depth)); case STATS_LEAST_SQUARES_REGRESSION: - return std::unique_ptr(new LeastSquaresRegressionGrowStats( - params_, depth)); + return std::unique_ptr( + new LeastSquaresRegressionGrowStats(params_, depth)); case STATS_FIXED_SIZE_SPARSE_GINI: return std::unique_ptr( @@ -136,8 +136,7 @@ void SplitCollectionOperator::CreateAndInitializeCandidateWithExample( stats_.at(node_id)->AddSplit(split, input_data, target, example); } -bool SplitCollectionOperator::BestSplit(int32 node_id, - SplitCandidate* best, +bool SplitCollectionOperator::BestSplit(int32 node_id, SplitCandidate* best, int32* depth) const { auto* slot = stats_.at(node_id).get(); *depth = slot->depth(); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h index ad52f89faddb15be77644b5dc374aca73c46b149..c606ff98c67f411a5817f0282238fdaf3be03642 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h @@ -71,9 +71,7 @@ class SplitCollectionOperator { } // Perform any necessary cleanup for any tracked state for the slot. - virtual void ClearSlot(int32 node_id) { - stats_.erase(node_id); - } + virtual void ClearSlot(int32 node_id) { stats_.erase(node_id); } // Return true if slot is fully initialized. virtual bool IsInitialized(int32 node_id) const; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc index 0bec198e97e8215d2cfdb9ada5355dd5b0d2d97b..c749fbe69e17769c2f2b69bcf541eb0eb8b9e7e8 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc @@ -32,9 +32,9 @@ namespace tensorforest { // smoothed_sum = stats.sum() + #_classes float GiniImpurity(const LeafStat& stats, int32 num_classes) { const float smoothed_sum = num_classes + stats.weight_sum(); - return 1.0 - ( - (stats.classification().gini().square() - + 2 * stats.weight_sum() + num_classes) / (smoothed_sum * smoothed_sum)); + return 1.0 - ((stats.classification().gini().square() + + 2 * stats.weight_sum() + num_classes) / + (smoothed_sum * smoothed_sum)); } float WeightedGiniImpurity(const LeafStat& stats, int32 num_classes) { @@ -46,21 +46,20 @@ void UpdateGini(LeafStat* stats, float old_val, float weight) { // Equivalent to stats->square() - old_val * old_val + new_val * new_val, // (for new_val = old_val + weight), but more numerically stable. stats->mutable_classification()->mutable_gini()->set_square( - stats->classification().gini().square() - + weight * weight + 2 * old_val * weight); + stats->classification().gini().square() + weight * weight + + 2 * old_val * weight); } - float Variance(const LeafStat& stats, int output) { if (stats.weight_sum() == 0) { return 0; } const float e_x = - stats.regression().mean_output().value(output).float_value() - / stats.weight_sum(); + stats.regression().mean_output().value(output).float_value() / + stats.weight_sum(); const auto e_x2 = - stats.regression().mean_output_squares().value(output).float_value() - / stats.weight_sum(); + stats.regression().mean_output_squares().value(output).float_value() / + stats.weight_sum(); return e_x2 - e_x * e_x; } @@ -75,8 +74,7 @@ float TotalVariance(const LeafStat& stats) { float SmoothedGini(float sum, float square, int num_classes) { // See comments for GiniImpurity above. const float smoothed_sum = num_classes + sum; - return 1.0 - - (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum); + return 1.0 - (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum); } float WeightedSmoothedGini(float sum, float square, int num_classes) { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h index 289c81e9d51dbc5d2023f7eabce8c2089748645d..38deb3e3cd816aae5fe66f26cd4b934316d03ce4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h @@ -27,9 +27,7 @@ class TestableInputTarget : public StoredInputTarget> { : StoredInputTarget(new std::vector(t), new std::vector(w), num_t) {} - int NumItems() const { - return target_->size(); - } + int NumItems() const { return target_->size(); } int32 GetTargetAsClassIndex(int example_index, int target_index) const override { @@ -51,7 +49,6 @@ class TestableInputTarget : public StoredInputTarget> { } }; - class TestableDataSet : public TensorDataSet { public: TestableDataSet(const std::vector& data, int num_features) diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 0199313bc8d0214a547498b97e9a1d83ee37b708..a7d54d8a0cc4598c26d1c7bd62f5b0aa1070701b 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -43,6 +43,7 @@ py_library( deps = [ ":tpu_lib", ":tpu_py", + "//tensorflow/contrib/summary:summary_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 6a05a2abf6fa57938af2b9b02ce394a9f6b7fe6e..b1ef9fde37fe0647965f0818895be37d2d56d207 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/contrib/tpu/profiler/version.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -47,6 +48,19 @@ string GetCurrentTimeStampAsString() { return s; } +Status ValidateHostPortPair(const string& host_port) { + uint32 port; + std::vector parts = str_util::Split(host_port, ':'); + // Must be host:port, port must be a number, host must not contain a '/', + // host also must not be empty. + if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || + parts[0].find("/") != string::npos || parts[0].empty()) { + return errors::InvalidArgument("Could not interpret \"", host_port, + "\" as a host-port pair."); + } + return Status::OK(); +} + ProfileResponse Profile(const string& service_addr, int duration_ms, const ProfileOptions& opts) { ProfileRequest request; @@ -60,11 +74,14 @@ ProfileResponse Profile(const string& service_addr, int duration_ms, ::grpc::ClientContext context; ::grpc::ChannelArguments channel_args; // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available. + // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their + // `ValidateHostPortPair` checks for empty host string case. channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits::max()); std::unique_ptr stub = TPUProfiler::NewStub(::grpc::CreateCustomChannel( - service_addr, ::grpc::InsecureChannelCredentials(), channel_args)); + "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), + channel_args)); ProfileResponse response; TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); return response; @@ -101,7 +118,14 @@ int main(int argc, char** argv) { tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) { - std::printf("%s", usage.c_str()); + std::cout << usage.c_str() << std::endl; + return 2; + } + tensorflow::Status status = + tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr); + if (!status.ok()) { + std::cout << status.error_message() << std::endl; + std::cout << usage.c_str() << std::endl; return 2; } tensorflow::port::InitMain(argv[0], &argc, &argv); @@ -130,6 +154,8 @@ int main(int argc, char** argv) { << std::endl << "Tip: increase number of attempts with --num_tracing_attempts." << std::endl; + // Don't dump profile data if no trace is collected. + return 0; } // Use the current timestamp as the run name. diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index 64e4e6275dd0619796b5088f3b5f25b134a89b22..ebd6185faad28ae7a22eb33f6b358eb2344c9c22 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -151,8 +151,7 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run, TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir)); // Ignore computation_graph for now. - const bool empty_trace = response.encoded_trace().empty(); - if (!empty_trace) { + if (!response.encoded_trace().empty()) { LOG(INFO) << "Converting trace events to TraceViewer JSON."; TF_RETURN_IF_ERROR( DumpTraceToLogDirectory(profile_run_dir, response.encoded_trace(), os)); @@ -163,11 +162,9 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run, TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, response.op_profile(), os)); } - if (!empty_trace && !response.tool_data().empty()) { - for (const auto& tool_data : response.tool_data()) { - TF_RETURN_IF_ERROR( - DumpToolDataToLogDirectory(profile_run_dir, tool_data, os)); - } + for (const auto& tool_data : response.tool_data()) { + TF_RETURN_IF_ERROR( + DumpToolDataToLogDirectory(profile_run_dir, tool_data, os)); } return Status::OK(); diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h index 2f8656a37be031fbe72fdda355040f88605fde78..29ef977bacfd61e163be49558c5b94277ed479c1 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h @@ -29,6 +29,8 @@ namespace tpu { // - Op profile // - Input pipeline analyzer // - Overview page +// Note: this function creates a directory even when all fields in +// ProfileResponse are unset/empty. Status WriteTensorboardTPUProfile(const string& logdir, const string& run, const ProfileResponse& response, std::ostream* os); diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 885466e5d17b726f4334674fe15f61ed2a529788..78d237e6a201541b6095b101311db48b447cc477 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -25,17 +25,19 @@ import sys import tensorflow as tf -flags.DEFINE_string('service_addr', None, - 'Address of TPU profiler service e.g. localhost:8466') -flags.DEFINE_string('logdir', None, - "Path of TensorBoard log directory e.g. /tmp/tb_log, " - "gs://tb_bucket") +flags.DEFINE_string( + 'service_addr', None, 'Address of TPU profiler service e.g. ' + 'localhost:8466') +flags.DEFINE_string( + 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' + 'gs://tb_bucket') flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') -flags.DEFINE_integer('num_tracing_attempts', 3, - "Automatically retry N times when no trace event is " - "collected.") -flags.DEFINE_boolean('include_dataset_ops', True, - "Set to false to profile longer TPU device traces.") +flags.DEFINE_integer( + 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' + 'event is collected.') +flags.DEFINE_boolean( + 'include_dataset_ops', True, 'Set to false to profile longer TPU ' + 'device traces.') FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index 3dffebe66801bf5f6956ac4297e8dbededf5a4f7..cb6198479908943a546710b94f059d27d9e41a84 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.5.0-rc1' +_VERSION = '1.6.0-rc0' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 8fec379aad8a90d06cd05f4858d25656384a12b2..d5f54ff4fd278f0c84f79e0079bfb7a409dfba8d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -153,10 +153,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): raise NotImplementedError( "Non-resource Variables are not supported inside TPU computations " "(operator name: %s)" % op.name) - # pylint: enable=protected-access if _TPU_REPLICATE_ATTR in op.node_def.attr: raise ValueError("TPU computations cannot be nested") - op.node_def.attr[_TPU_REPLICATE_ATTR].s = compat.as_bytes(self._name) + op._set_attr(_TPU_REPLICATE_ATTR, + attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) + # pylint: enable=protected-access op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 0c2580211ab7674d841ca1953c9327df9488bb8e..188db6e2f0d12ed441c043674df9e2e6bec7cc14 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -53,7 +53,8 @@ class TPUConfig( num_shards: The number of TPU shards in the system. per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host rather than Per-Core. With Per-Host input pipeline deployment, `input_fn` - is invoked once on each host. To be precise, with a global batch size + is invoked once on each host. With Per-Core input pipeline deployment, it + is invoked once for each core. To be precise, with a global batch size `train_batch_size` in `TPUEstimator` constructor, the batch size for each shard is `train_batch_size` // #hosts. With Per-Core input pipeline deployment, the shard batch size is `train_batch_size` // #cores. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 2ae3a26a853bf4941ac3855ec525293b5a508a2a..56793f11d9c54d00c9cb2a535669de9aec5315e3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections from contextlib import contextmanager import copy +import signal import threading import time import traceback @@ -29,6 +30,7 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.summary import summary_ops as contrib_summary from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config @@ -39,11 +41,13 @@ from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -59,6 +63,7 @@ from tensorflow.python.training import evaluation from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util +from tensorflow.python.util import tf_inspect _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. @@ -68,7 +73,12 @@ _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] -# TODO(b/65703635): Flip the value and remove all dead code. + +# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is +# only used for per-core based deployments. For per-host based pipelines, if a +# user returns a Dataset instance it will be automatically wrapped in a +# tf.while_loop (This can be disabled by returning features and labels +# explicitly). _WRAP_INPUT_FN_INTO_WHILE_LOOP = False @@ -162,10 +172,12 @@ class _TPUContext(object): ``` """ - def __init__(self, config, train_batch_size, eval_batch_size, use_tpu): + def __init__(self, config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu): self._config = config self._train_batch_size = train_batch_size self._eval_batch_size = eval_batch_size + self._predict_batch_size = predict_batch_size self._use_tpu = use_tpu self._num_shards_or_none = self._config.tpu_config.num_shards self._mode = None @@ -210,39 +222,66 @@ class _TPUContext(object): return (self._mode == model_fn_lib.ModeKeys.TRAIN and not self._config.tpu_config.per_host_input_for_training) - def is_running_on_cpu(self): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" + def is_running_on_cpu(self, is_export_mode=False): + """Determines whether the input_fn and model_fn should be invoked on CPU. + + Args: + is_export_mode: Indicates whether the current mode is for exporting the + model, when mode == PREDICT. Only with this bool, we could + tell whether user is calling the Estimator.predict or + Estimator.export_savedmodel, which are running on TPU and CPU + respectively. Parent class Estimator does not distingush these two. + + Returns: + bool, whether current input_fn or model_fn should be running on CPU. + + Raises: + ValueError: any configuration is invalid. + """ mode = self._assert_mode() - return ((not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or - (mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is None)) + + if not self._use_tpu: + return True + + if mode != model_fn_lib.ModeKeys.PREDICT: + return False + + # There are actually 2 use cases when running with mode.PREDICT: prediction + # and saving the model. We run actual predictions on the TPU, but + # model export is run on the CPU. + if is_export_mode: + return True + + if self._predict_batch_size is None: + raise ValueError( + 'predict_batch_size in TPUEstimator constructor should not be ' + '`None` if .predict is running on TPU.') + if self.num_hosts > 1: + raise ValueError( + 'TPUEstimator.predict should be running on single host.') + + return False @property def global_batch_size(self): mode = self._assert_mode() - if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: - raise RuntimeError('Internal error, EVAL on TPU is not enabled, but ' - '`global_batch_size` is called.') - return (self._train_batch_size - if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size) + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + elif mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + elif mode == model_fn_lib.ModeKeys.PREDICT: + return self._predict_batch_size + else: + return None @property def batch_size_for_input_fn(self): """Returns the shard batch size for `input_fn`.""" - mode = self._assert_mode() - # Special case for eval. - if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: - return None + global_batch_size = self.global_batch_size + if self.is_running_on_cpu(): - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - if mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - return None + return global_batch_size - global_batch_size = ( - self._train_batch_size - if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size) # On TPU if self.is_input_sharded_per_core(): return global_batch_size // self.num_cores @@ -252,22 +291,13 @@ class _TPUContext(object): @property def batch_size_for_model_fn(self): """Returns the shard batch size for `model_fn`.""" - mode = self._assert_mode() - # Special case for eval. - if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: - return None + global_batch_size = self.global_batch_size + if self.is_running_on_cpu(): - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - if mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - return None + return global_batch_size # On TPU. always sharded per core. - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size // self.num_cores - else: - return self._eval_batch_size // self.num_cores + return global_batch_size // self.num_cores @property def master_job(self): @@ -384,7 +414,8 @@ class TPUEstimatorSpec( 'train_op', 'eval_metrics', 'export_outputs', - 'scaffold_fn' + 'scaffold_fn', + 'host_call' ])): """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. @@ -410,6 +441,15 @@ class TPUEstimatorSpec( `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This function should not capture any Tensors in `model_fn`. + + `host_call` is a tuple of a `function` and a list or dictionary of `tensors` + to pass to that function and returns a list of Tensors. `host_call` currently + works for train() and evaluate(). The Tensors returned by the function is + executed on the CPU on every step, so there is communication overhead when + sending tensors from TPU to CPU. To reduce the overhead, try reducing the + size of the tensors. The `tensors` are concatenated along their major (batch) + dimension, and so must be >= rank 1. The `host_call` is useful for writing + summaries with @{tf.contrib.summary.create_file_writer}. """ def __new__(cls, @@ -419,10 +459,15 @@ class TPUEstimatorSpec( train_op=None, eval_metrics=None, export_outputs=None, - scaffold_fn=None): + scaffold_fn=None, + host_call=None): """Creates a validated `TPUEstimatorSpec` instance.""" + host_calls = {} if eval_metrics is not None: - _EvalMetrics.validate(eval_metrics) + host_calls['eval_metrics'] = eval_metrics + if host_call is not None: + host_calls['host_call'] = host_call + _OutfeedHostCall.validate(host_calls) return super(TPUEstimatorSpec, cls).__new__( cls, mode=mode, @@ -431,12 +476,23 @@ class TPUEstimatorSpec( train_op=train_op, eval_metrics=eval_metrics, export_outputs=export_outputs, - scaffold_fn=scaffold_fn) + scaffold_fn=scaffold_fn, + host_call=host_call) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" - eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu( - self.eval_metrics) + host_calls = {} + if self.eval_metrics is not None: + host_calls['eval_metrics'] = self.eval_metrics + if self.host_call is not None: + host_calls['host_call'] = wrap_hostcall_with_global_step(self.host_call) + host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) + eval_metric_ops = None + if self.eval_metrics is not None: + eval_metric_ops = host_call_ret['eval_metrics'] + hooks = None + if self.host_call is not None: + hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -445,7 +501,10 @@ class TPUEstimatorSpec( train_op=self.train_op, eval_metric_ops=eval_metric_ops, export_outputs=self.export_outputs, - scaffold=scaffold) + scaffold=scaffold, + training_hooks=hooks, + evaluation_hooks=hooks, + prediction_hooks=hooks) class _OpQueueContext(object): @@ -467,12 +526,12 @@ class _OpQueueContext(object): def read_iteration_counts(self): while True: - signal = self._queue.get(block=True) - logging.debug('%s read signal %s', self._name, signal) - if signal == _SIGNAL.STOP: - logging.info('%s received signal, stopping.', self._name) + iterations = self._queue.get(block=True) + logging.debug('%s read iterations %s', self._name, iterations) + if iterations == _SIGNAL.STOP: + logging.info('%s received shutdown signal, stopping.', self._name) return - yield signal + yield iterations def join(self): logging.info('Shutting down %s thread.' % self._name) @@ -480,6 +539,22 @@ class _OpQueueContext(object): self._thread.join() +class _OpSignalOnceQueueContext(_OpQueueContext): + """Manages work queue and thread for a infeed/outfeed thread. + + This subclass only signals once. + """ + + def __init__(self, name, target, args): + super(_OpSignalOnceQueueContext, self).__init__(name, target, args) + self._has_signaled = False + + def send_next_batch_signal(self, iterations): + if not self._has_signaled: + self._queue.put(iterations) + self._has_signaled = True + + class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): """A Session hook setting up the TPU initialization, infeed, and outfeed. @@ -489,12 +564,19 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): dequeue. """ - def __init__(self, ctx, enqueue_ops, dequeue_ops=None): + def __init__(self, + ctx, + enqueue_ops, + dequeue_ops, + run_infeed_loop_on_coordinator=True): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops + + self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) + self._session_cancel_timer = None self._feed_error = None @@ -503,8 +585,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_op = [tpu.initialize_system(job=self._master_job)] - self._finalize_op = [tpu.shutdown_system(job=self._master_job)] + self._init_ops = [tpu.initialize_system(job=self._master_job)] + self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] + + summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() + self._init_ops.extend(summary_writer_init_ops) + # Get all the writer resources from the initializer, so we know what to + # flush. + for op in summary_writer_init_ops: + self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) def _log_error(self, session, error): """Log an infeed or outfeed error. @@ -516,8 +605,9 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): emitting a stack trace for the infeed. Args: - session: `tf.Session`, session to be terminated - error: exception that triggered logging. + session: `tf.Session`, session to be terminated error: exception that + triggered logging. + error: the Exception to log. """ logging.warning( '\n\n' @@ -569,15 +659,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): logging.info('%s thread starting after sleep', self._name) try: - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - for _ in queue_ctx.read_iteration_counts(): - session.run(self._enqueue_ops) - else: + if self._run_infeed_loop_on_coordinator: for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) session.run(self._enqueue_ops) - logging.debug('Infeed thread finished, shutting down.') + else: + for _ in queue_ctx.read_iteration_counts(): + session.run(self._enqueue_ops) + logging.info('Infeed thread finished, shutting down.') except Exception as e: # pylint: disable=broad-except self._log_error(session, e) @@ -588,23 +678,25 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for i in xrange(steps): logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) session.run(self._dequeue_ops) + logging.info('Outfeed thread finished, shutting down.') except Exception as e: # pylint: disable=broad-except self._log_error(session, e) + def _create_infeed_controller(self, name, target, args): + return _OpQueueContext(name=name, target=target, args=args) + def after_create_session(self, session, coord): logging.info('Init TPU system') - session.run( - self._init_op, - options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) + session.run(self._init_ops, + options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) logging.info('Start infeed thread controller') - self._infeed_controller = _OpQueueContext( + self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) - if self._dequeue_ops is not None: - logging.info('Start outfeed thread controller') - self._outfeed_controller = _OpQueueContext( - name='OutfeedController', target=self._run_outfeed, args=(session,)) + logging.info('Start outfeed thread controller') + self._outfeed_controller = _OpQueueContext( + name='OutfeedController', target=self._run_outfeed, args=(session,)) def before_run(self, run_context): if self._feed_error: @@ -617,11 +709,9 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) self._infeed_controller.send_next_batch_signal(iterations) - if self._dequeue_ops is not None: - # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. - logging.info('Dequeue next (%d) batch(es) of data from outfeed.', - iterations) - self._outfeed_controller.send_next_batch_signal(iterations) + logging.info('Dequeue next (%d) batch(es) of data from outfeed.', + iterations) + self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): if self._session_cancel_timer: @@ -632,12 +722,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): logging.info('Stop infeed thread controller') self._infeed_controller.join() - if self._dequeue_ops is not None: - logging.info('Stop output thread controller') - self._outfeed_controller.join() + logging.info('Stop output thread controller') + self._outfeed_controller.join() logging.info('Shutdown TPU system.') - session.run(self._finalize_op) + session.run(self._finalize_ops) + + +class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): + + def __init__(self, ctx, enqueue_ops, dequeue_ops): + super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( + ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False) + + def _create_infeed_controller(self, name, target, args): + return _OpSignalOnceQueueContext(name=name, target=target, args=args) class _TPUStopAtStepHook(session_run_hook.SessionRunHook): @@ -727,6 +826,47 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook): self._iterations_per_loop_var.load(self._num_steps, session=session) +class _StoppingPredictHook(session_run_hook.SessionRunHook): + """Hook that requests stop according to the stopping signal in prediction.""" + + def __init__(self, scalar_stopping_signal): + self._scalar_stopping_signal = scalar_stopping_signal + + def begin(self): + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + + def after_create_session(self, session, coord): + # This is not necessary as we do not run infeed enqueue and outfeed dequeue + # in side threads for prediction model. But it makes the + # TPUInfeedOutfeedSessionHook prints nice message. + self._iterations_per_loop_var.load(1, session=session) + + def before_run(self, run_context): + return session_run_hook.SessionRunArgs(self._scalar_stopping_signal) + + def after_run(self, run_context, run_values): + _ = run_context + scalar_stopping_signal = run_values.results + if _StopSignals.should_stop(scalar_stopping_signal): + # NOTE(xiejw): In prediction, stopping signals are inserted for each + # batch. And we append one more batch to signal the system it should stop. + # The data flow might look like + # + # batch 0: images, labels, stop = 0 (user provideded) + # batch 1: images, labels, stop = 0 (user provideded) + # ... + # batch 99: images, labels, stop = 0 (user provideded) + # batch 100: images, labels, stop = 1 (TPUEstimator appended) + # + # where the final batch (id = 100) is appended by TPUEstimator, so we + # should drop it before returning the predictions to user. + # To achieve that, we throw the OutOfRangeError in after_run. Once + # Monitored Session sees this error in SessionRunHook.after_run, the + # "current" prediciton, i.e., batch with id=100, will be discarded + # immediately + raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') + + def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, inputs_structure_recorder): """Generates infeed enqueue ops for per-core input_fn on a single host.""" @@ -738,11 +878,14 @@ def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, per_host_sharded_inputs = [] for core_ordinal in range(num_cores_per_host): with ops.name_scope('ordinal_%d' % (core_ordinal)): - inputs = input_fn() - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None + inputs = _Inputs.from_input_fn(input_fn()) + if inputs.is_dataset: + raise TypeError( + '`input_fn` returning `Dataset` is not yet supported in ' + 'per-Core input pipeline deployment yet. Please set ' + 'TPUConfig.per_host_input_for_training to True or return ' + '`features` and `labels` from `input_fn`') + features, labels = inputs.features_and_labels() inputs_structure_recorder.validate_and_record_structure( features, labels) @@ -769,18 +912,37 @@ def generate_per_host_enqueue_ops_fn_for_host( """Generates infeed enqueue ops for per-host input_fn on a single host.""" captured_infeed_queue = _CapturedObject() + hooks = [] + + with ops.device(device): + inputs = _Inputs.from_input_fn(input_fn()) + + is_dataset = inputs.is_dataset + if ctx.mode == model_fn_lib.ModeKeys.PREDICT: + if not is_dataset: + raise TypeError( + 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' + '`features` and `labels`.') + inputs = _InputsWithStoppingSignals( + dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn) + + if is_dataset: + hooks.append(inputs.dataset_initializer_hook()) + def enqueue_ops_fn(): with ops.device(device): num_cores_per_host = ctx.num_of_cores_per_host - inputs = input_fn() - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None - inputs_structure_recorder.validate_and_record_structure(features, labels) + # Convert user input to features and labels. If the user returns a + # dataset, it is initialized and the features and labels extracted via + # `dataset.iterator.get_next()` + features, labels = inputs.features_and_labels() + signals = inputs.signals() + + inputs_structure_recorder.validate_and_record_structure( + features, labels, signals) unsharded_tensor_list = ( inputs_structure_recorder.flatten_features_and_labels( - features, labels)) + features, labels, signals)) infeed_queue = tpu_feed.InfeedQueue( tuple_types=[t.dtype for t in unsharded_tensor_list], @@ -792,9 +954,15 @@ def generate_per_host_enqueue_ops_fn_for_host( per_host_enqueue_ops = ( infeed_queue.split_inputs_and_generate_enqueue_ops( unsharded_tensor_list, placement_function=lambda x: device)) - return per_host_enqueue_ops + if signals is None: + return per_host_enqueue_ops + else: + return { + 'ops': per_host_enqueue_ops, + 'signals': signals, + } - return enqueue_ops_fn, captured_infeed_queue + return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset class _InputPipeline(object): @@ -834,6 +1002,7 @@ class _InputPipeline(object): self._feature_names = [] self._label_names = [] self._has_labels = False + self._signals_helper = None # Internal state. self._initialized = False @@ -841,7 +1010,7 @@ class _InputPipeline(object): def has_labels(self): return self._has_labels - def validate_and_record_structure(self, features, labels): + def validate_and_record_structure(self, features, labels, signals=None): """Validates and records the structure of features` and `labels`.""" def _extract_key_names(tensor_or_dict): @@ -854,6 +1023,10 @@ class _InputPipeline(object): feature_names = _extract_key_names(features) label_names = _extract_key_names(labels) + if signals is not None and self._signals_helper is None: + # Record signals helper. + self._signals_helper = _SignalsHelper(signals) + if self._initialized: # Verify the structure is same. The following should never happen. assert feature_names == self._feature_names, 'feature keys mismatched' @@ -866,7 +1039,7 @@ class _InputPipeline(object): self._label_names = label_names self._has_labels = has_labels - def flatten_features_and_labels(self, features, labels): + def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" flattened_inputs = [] if self._feature_names: @@ -882,6 +1055,9 @@ class _InputPipeline(object): flattened_inputs.extend([labels[name] for name in self._label_names]) else: flattened_inputs.append(labels) + + if signals is not None: + flattened_inputs.extend(_SignalsHelper.as_tensor_list(signals)) return flattened_inputs def unflatten_features_and_labels(self, flattened_inputs): @@ -907,7 +1083,11 @@ class _InputPipeline(object): else: expected_num_labels = 0 - expected_num_tensors = expected_num_features + expected_num_labels + expected_num_signals = ( + self._signals_helper.num_signals if self._signals_helper else 0) + + expected_num_tensors = ( + expected_num_features + expected_num_labels + expected_num_signals) if expected_num_tensors != len(flattened_inputs): raise ValueError( @@ -924,13 +1104,20 @@ class _InputPipeline(object): if expected_num_labels == 0: unflattened_label = None elif self._label_names: - unflattened_label = dict( - zip(self._label_names, flattened_inputs[expected_num_features:])) + label_list = flattened_inputs[ + expected_num_features:expected_num_features + expected_num_labels] + unflattened_label = dict(zip(self._label_names, label_list)) else: # Single tensor case. unflattened_label = flattened_inputs[expected_num_features] - return unflattened_features, unflattened_label + signals = None + if expected_num_signals != 0: + tensor_list_for_signals = flattened_inputs[ + expected_num_features + expected_num_labels:] + signals = self._signals_helper.unflatten(tensor_list_for_signals) + + return _Inputs(unflattened_features, unflattened_label, signals=signals) def __init__(self, input_fn, batch_axis, ctx): """Constructor. @@ -958,7 +1145,8 @@ class _InputPipeline(object): # While tf.while_loop is called, the body function, which invokes # `enqueue_fn` passed in, is called to construct the graph. So, input_fn # structure is recorded. - enqueue_ops = self._invoke_input_fn_and_record_structure() + enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = ( + self._invoke_input_fn_and_record_structure()) self._validate_input_pipeline() @@ -969,14 +1157,18 @@ class _InputPipeline(object): return self._inputs_structure_recorder.unflatten_features_and_labels( values) - return (enqueue_ops, dequeue_fn) + return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator) def _invoke_input_fn_and_record_structure(self): """Deploys the input pipeline and record input structure.""" enqueue_ops = [] infeed_queues = [] + all_hooks = [] num_hosts = self._ctx.num_hosts tpu_host_placement_fn = self._ctx.tpu_host_placement_function + + run_infeed_loop_on_coordinator = True + if self._sharded_per_core: # Per-Core input pipeline deployment. # Invoke input pipeline for each core and placed on the corresponding @@ -990,6 +1182,7 @@ class _InputPipeline(object): self._ctx, self._input_fn, self._inputs_structure_recorder)) if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + run_infeed_loop_on_coordinator = False enqueue_ops.append( _wrap_computation_in_while_loop( device=host_device, op_fn=enqueue_ops_fn)) @@ -1003,15 +1196,32 @@ class _InputPipeline(object): host_device = tpu_host_placement_fn(host_id=host_id) with ops.device(host_device): with ops.name_scope('input_pipeline_task%d' % (host_id)): - enqueue_ops_fn, captured_infeed_queue = ( + enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( generate_per_host_enqueue_ops_fn_for_host( self._ctx, self._input_fn, self._inputs_structure_recorder, self._batch_axis, host_device)) - - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + all_hooks.extend(hooks) + + # NOTE(xiejw): We dispatch here based on the return type of the + # users `input_fn`. + # + # 1. If input_fn returns a Dataset instance, we initialize the + # iterator outside of tf.while_loop, and call the iterator.get_next + # inside tf.while_loop. This should be always safe. + # + # 2. If input_fn returns (features, labels), it is too late to wrap + # them inside tf.while_loop, as resource initialization cannot be + # handled in TF control flow properly. In this case, we will use + # python loop to enqueue the data into TPU system. This may be + # slow compared to the previous case. + if is_dataset: + run_infeed_loop_on_coordinator = False + wrap_fn = ( + _wrap_computation_in_while_loop + if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else + _wrap_computation_in_while_loop_with_stopping_signals) enqueue_ops.append( - _wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) + wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) infeed_queues.append(captured_infeed_queue.get()) @@ -1019,7 +1229,7 @@ class _InputPipeline(object): # dequeue is dtypes and types. So, any one can be used. Here, grab the # first one. self._infeed_queue = infeed_queues[0] - return enqueue_ops + return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): # Perform some sanity checks to log user friendly information. We should @@ -1076,29 +1286,38 @@ class _ModelFnWrapper(object): infeed dequeue channel. Returns: - A Fn representing the train step for TPU. + A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn + representing the train step for TPU. """ + host_call = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() def train_step(loss): """Training step function for use inside a while loop.""" del loss # unused; required in function signature. - features, labels = dequeue_fn() + inputs = dequeue_fn() + features, labels = inputs.features_and_labels() estimator_spec = self._verify_estimator_spec( self._call_model_fn(features, labels)) loss, train_op = estimator_spec.loss, estimator_spec.train_op + host_call_outfeed_ops = [] if isinstance(estimator_spec, TPUEstimatorSpec): captured_scaffold_fn.capture(estimator_spec.scaffold_fn) + if estimator_spec.host_call is not None: + host_call.record({ + 'host_call': wrap_hostcall_with_global_step( + estimator_spec.host_call)}) + host_call_outfeed_ops = host_call.create_enqueue_op() else: captured_scaffold_fn.capture(None) - with ops.control_dependencies([train_op]): + with ops.control_dependencies([train_op] + host_call_outfeed_ops): return array_ops.identity(loss) - return train_step, captured_scaffold_fn + return train_step, host_call, captured_scaffold_fn def convert_to_single_tpu_eval_step(self, dequeue_fn): """Converts user provided model_fn` as a single eval step on TPU. @@ -1123,15 +1342,16 @@ class _ModelFnWrapper(object): infeed dequeue channel. Returns: - A tuple of eval_fn and eval_metrics. The eval_fn representing the eval - step for TPU. and eval_metrics is an `_EvalMetrics` instance. + A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn + representing the eval step for TPU. """ - eval_metrics = _EvalMetrics(self._ctx) + host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" - features, labels = dequeue_fn() + inputs = dequeue_fn() + features, labels = inputs.features_and_labels() tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): @@ -1141,15 +1361,68 @@ class _ModelFnWrapper(object): loss = tpu_estimator_spec.loss captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - eval_metrics.record(tpu_estimator_spec) - outfeed_ops = tpu_ops.outfeed_enqueue_tuple(eval_metrics.outfeed_tensors) - - with ops.control_dependencies([outfeed_ops]): + to_record = {} + to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics + if tpu_estimator_spec.host_call is not None: + # We assume that evaluate won't update global step, so we don't wrap + # this host_call. + to_record['host_call'] = tpu_estimator_spec.host_call + host_calls.record(to_record) + + with ops.control_dependencies(host_calls.create_enqueue_op()): return math_ops.add(total_loss, loss) - return eval_step, eval_metrics, captured_scaffold_fn + return eval_step, host_calls, captured_scaffold_fn + + def convert_to_single_tpu_predict_step(self, dequeue_fn): + """Converts user provided model_fn` as a single predict step on TPU. + + Args: + dequeue_fn: The function to retrieve inputs, features and labels, from TPU + infeed dequeue channel. + + Returns: + A tuple of predict_fn, host_calls, and captured scaffold_fn. The + predict_fn representing the predict step for TPU. + """ + host_calls = _OutfeedHostCall(self._ctx) + captured_scaffold_fn = _CapturedObject() + + def predict_step(unused_scalar_stopping_signal): + """Evaluation step function for use inside a while loop.""" + inputs = dequeue_fn() + features, labels = inputs.features_and_labels() + stopping_signals = inputs.signals() + + assert stopping_signals is not None, ( + 'Internal Error: `signals` is missing.') - def _call_model_fn(self, features, labels): + tpu_estimator_spec = self._call_model_fn( + features, labels, is_export_mode=False) + if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): + raise RuntimeError( + 'estimator_spec used by TPU prediction must have type' + '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) + + captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + to_record = {} + identity_fn = lambda **kwargs: kwargs + # TODO(xiejw): Adds validation for prediction dictionrary. + # TODO(xiejw): Adds support for single tensor as predictions. + if not isinstance(tpu_estimator_spec.predictions, dict): + raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') + to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] + to_record['signals'] = [identity_fn, stopping_signals] + if tpu_estimator_spec.host_call is not None: + to_record['host_call'] = tpu_estimator_spec.host_call + host_calls.record(to_record) + + with ops.control_dependencies(host_calls.create_enqueue_op()): + return _StopSignals.as_scalar_stopping_signal(stopping_signals) + + return predict_step, host_calls, captured_scaffold_fn + + def _call_model_fn(self, features, labels, is_export_mode=True): """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(self._model_fn) kwargs = {} @@ -1180,7 +1453,7 @@ class _ModelFnWrapper(object): params[_BATCH_SIZE_KEY] = batch_size_for_model_fn estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu() and + if (self._ctx.is_running_on_cpu(is_export_mode) and isinstance(estimator_spec, TPUEstimatorSpec)): # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. @@ -1207,158 +1480,241 @@ class _ModelFnWrapper(object): return estimator_spec -class _EvalMetrics(object): - """Class wraps TPUEstimator.eval_metrics.""" +class _OutfeedHostCall(object): + """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" def __init__(self, ctx): self._ctx = ctx - self._metric_fn = None - self._is_dict = False - self._tensor_keys = [] - self._tensors = [] - self._tensor_dtypes = [] - self._tensor_shapes = [] - self._recorded = False + self._names = [] + # All of these are dictionaries of lists keyed on the name. + self._host_fns = {} + self._tensor_keys = collections.defaultdict(list) + self._tensors = collections.defaultdict(list) + self._tensor_dtypes = collections.defaultdict(list) + self._tensor_shapes = collections.defaultdict(list) @staticmethod - def validate(eval_metrics): - """Validates the `eval_metrics` in `TPUEstimatorSpec`.""" - - if not isinstance(eval_metrics, (tuple, list)): - raise ValueError('eval_metrics should be tuple or list') - if len(eval_metrics) != 2: - raise ValueError('eval_metrics should have two elements.') - if not callable(eval_metrics[0]): - raise TypeError('eval_metrics[0] should be callable.') - if not isinstance(eval_metrics[1], (tuple, list, dict)): - raise ValueError('eval_metrics[1] should be tuple or list, or dict.') - - if isinstance(eval_metrics[1], (tuple, list)): - fn_args = util.fn_args(eval_metrics[0]) - if len(eval_metrics[1]) != len(fn_args): - raise RuntimeError( - 'In TPUEstimatorSpec.eval_metrics, length of tensors does not ' - 'match method args of metric_fn.') + def validate(host_calls): + """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" + + for name, host_call in host_calls.items(): + if not isinstance(host_call, (tuple, list)): + raise ValueError('{} should be tuple or list'.format(name)) + if len(host_call) != 2: + raise ValueError('{} should have two elements.'.format(name)) + if not callable(host_call[0]): + raise TypeError('{}[0] should be callable.'.format(name)) + if not isinstance(host_call[1], (tuple, list, dict)): + raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) + + if isinstance(host_call[1], (tuple, list)): + fullargspec = tf_inspect.getfullargspec(host_call[0]) + fn_args = util.fn_args(host_call[0]) + # wrapped_hostcall_with_global_step uses varargs, so we allow that. + if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): + raise RuntimeError( + 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' + 'method args of the function, which takes {}.'.format( + name, len(host_call[1]), len(fn_args))) @staticmethod - def to_metric_metric_ops_for_cpu(eval_metrics): - """Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU.""" - if not eval_metrics: - return None - - _EvalMetrics.validate(eval_metrics) + def create_cpu_hostcall(host_calls): + """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" + + _OutfeedHostCall.validate(host_calls) + ret = {} + for name, host_call in host_calls.items(): + host_fn, tensors = host_call + if isinstance(tensors, (tuple, list)): + ret[name] = host_fn(*tensors) + else: + # Must be dict. + try: + ret[name] = host_fn(**tensors) + except TypeError as e: + logging.warning( + 'Exception while calling %s: %s. It is likely the tensors ' + '(%s[1]) do not match the ' + 'function\'s arguments', name, e, name) + raise e + return ret + + def record(self, host_calls): + """Records the host_call structure.""" + + for name, host_call in host_calls.items(): + host_fn, tensor_list_or_dict = host_call + self._names.append(name) + self._host_fns[name] = host_fn + + if isinstance(tensor_list_or_dict, dict): + for (key, tensor) in six.iteritems(tensor_list_or_dict): + self._tensor_keys[name].append(key) + self._tensors[name].append(tensor) + self._tensor_dtypes[name].append(tensor.dtype) + self._tensor_shapes[name].append(tensor.shape) + else: + # List or tuple. + self._tensor_keys[name] = None + for tensor in tensor_list_or_dict: + self._tensors[name].append(tensor) + self._tensor_dtypes[name].append(tensor.dtype) + self._tensor_shapes[name].append(tensor.shape) - metric_fn, tensors = eval_metrics + def create_enqueue_op(self): + """Create the op to enqueue the recorded host_calls. - if isinstance(tensors, (tuple, list)): - return metric_fn(*tensors) - else: - # Must be dict. - try: - return metric_fn(**tensors) - except TypeError as e: - logging.warning( - 'Exception while calling metric_fn for evalution: %s. ' - 'It is likely the tensors (eval_metrics[1]) do not match the ' - 'metric_fn arguments', e) - raise e - - def record(self, spec): - """Records the eval_metrics structure in `spec`.""" - if self._recorded: - raise RuntimeError('Eval metrics have been recorded already.') - - self._metric_fn, tensor_list_or_dict = spec.eval_metrics - - if isinstance(tensor_list_or_dict, dict): - self._is_dict = True - for (key, tensor) in six.iteritems(tensor_list_or_dict): - self._tensor_keys.append(key) - self._tensors.append(tensor) - self._tensor_dtypes.append(tensor.dtype) - self._tensor_shapes.append(tensor.shape) - else: - # List or tuple. - self._is_dict = False - self._tensors = tensor_list_or_dict - for tensor in tensor_list_or_dict: - self._tensor_dtypes.append(tensor.dtype) - self._tensor_shapes.append(tensor.shape) - self._recorded = True - - @property - def outfeed_tensors(self): - if not self._recorded: - raise RuntimeError('Eval metrics have not been recorded yet') - return self._tensors + Returns: + A list of enqueue ops, which is empty if there are no host calls. + """ + if not self._names: + return [] - def to_metric_metric_ops_for_tpu(self, dummy_update_op): - """Creates the eval_metric_ops now based on the TPU outfeed. + tensors = [] + # TODO(jhseu): Consider deduping tensors. + for name in self._names: + tensors.extend(self._tensors[name]) + return [tpu_ops.outfeed_enqueue_tuple(tensors)] - `eval_metric_ops` is defined in `EstimatorSpec`. From all shards, tensors - are dequeued from outfeed and then concatenated (along batch size dimension) - to form global-like tensors. All global-like tensors are passed to the - metric fn. + def create_tpu_hostcall(self): + """Sends the tensors through outfeed and runs the host_fn on CPU. - Args: - dummy_update_op: A dummy update op. + The tensors are concatenated along dimension 0 to form a global tensor + across all shards. The concatenated function is passed to the host_fn and + executed on the first host. Returns: - A tuple of (`eval_metric_ops` and `update_ops`), where `update_ops` should - be invoked in Outfeed dequeue thread, which drive the outfeed dequeue and - update the state of metrics. + A dictionary mapping name to the return type of the host_call by that + name. Raises: RuntimeError: If outfeed tensor is scalar. """ + if not self._names: + return [] - num_cores = self._ctx.num_cores - + ret = {} # For each i, dequeue_ops[i] is a list containing the tensors from all # shards. This list is concatenated later. dequeue_ops = [] - for i in xrange(len(self._tensors)): - dequeue_ops.append([]) - - # Outfeed ops execute on each JF node. + tensor_dtypes = [] + tensor_shapes = [] + for name in self._names: + for _ in self._tensors[name]: + dequeue_ops.append([]) + for dtype in self._tensor_dtypes[name]: + tensor_dtypes.append(dtype) + for shape in self._tensor_shapes[name]: + tensor_shapes.append(shape) + + # Outfeed ops execute on each JF node. Note: we must constraint it such that + # we have at most one outfeed dequeue and enqueue. tpu_device_placement_fn = self._ctx.tpu_device_placement_function - for i in xrange(num_cores): + for i in xrange(self._ctx.num_cores): with ops.device(tpu_device_placement_fn(i)): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=self._tensor_dtypes, shapes=self._tensor_shapes) + dtypes=tensor_dtypes, shapes=tensor_shapes) for j, item in enumerate(outfeed_tensors): dequeue_ops[j].append(item) - # It is assumed evaluation always happends on single host TPU system. So, + # Deconstruct dequeue ops. + dequeue_ops_by_name = {} + pos = 0 + for name in self._names: + dequeue_ops_by_name[name] = dequeue_ops[pos:pos+len(self._tensors[name])] + pos += len(self._tensors[name]) + + # It is assumed evaluation always happens on single host TPU system. So, # place all ops on tpu host if possible. + # + # TODO(jhseu): Evaluate whether this is right for summaries. with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): - for i, item in enumerate(dequeue_ops): - if dequeue_ops[i][0].shape.ndims == 0: - raise RuntimeError( - 'All tensors outfed from TPU should preseve batch size ' - 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Allow users to specify the axis for batch size dimension. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) + for name in self._names: + dequeue_ops = dequeue_ops_by_name[name] + for i, item in enumerate(dequeue_ops): + if dequeue_ops[i][0].shape.ndims == 0: + raise RuntimeError( + 'All tensors outfed from TPU should preserve batch size ' + 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) + # TODO(xiejw): Allow users to specify the axis for batch size + # dimension. + dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) + + if self._tensor_keys[name] is not None: + # The user-provided eval_metrics[1] is a dict. + dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) + try: + ret[name] = self._host_fns[name](**dequeue_ops) + except TypeError as e: + logging.warning( + 'Exception while calling %s: %s. It is likely the tensors ' + '(%s[1]) do not match the ' + 'function\'s arguments', name, e, name) + raise e + else: + ret[name] = self._host_fns[name](*dequeue_ops) + + return ret + + +def wrap_hostcall_with_global_step(hostcall): + """Wrap the hostcall so that we update the global step upon every call.""" + if hostcall is None: + return None + host_fn, tensors = hostcall + + def global_step_host_fn(_global_step, *args, **kwargs): # pylint: disable=invalid-name + # Note that we don't have any ordering here, so the graph may see a + # global_step that's off by 1. + state_ops.assign( + training.get_global_step(), + math_ops.cast(_global_step[0], dtypes.int64)) + return host_fn(*args, **kwargs) + # Give the global step tensor a batch dimension. Reshape is not supported for + # int64, so we cast it to int32. + # TODO(jhseu): Remove the cast once int64 is supported. + global_step_tensor = array_ops.reshape( + math_ops.cast(training.get_global_step(), dtypes.int32), [1]) + if isinstance(tensors, dict): + outfeed_tensors = {'_global_step': global_step_tensor} + outfeed_tensors.update(tensors) + return global_step_host_fn, outfeed_tensors + else: + fn_args = util.fn_args(host_fn) + if len(tensors) != len(fn_args): + raise RuntimeError( + 'In TPUEstimatorSpec.host_call, length of tensors {} does not match ' + 'method args of the function, which takes {}.'.format( + len(tensors), len(fn_args))) + return global_step_host_fn, [global_step_tensor] + list(tensors) - if self._is_dict: - dequeue_ops = dict(zip(self._tensor_keys, dequeue_ops)) - try: - eval_metric_ops = self._metric_fn(**dequeue_ops) - except TypeError as e: - logging.warning( - 'Exception while calling metric_fn for evalution: %s. ' - 'It is likely the tensors (eval_metrics[1]) do not match the ' - 'metric_fn arguments', e) - raise e - else: - eval_metric_ops = self._metric_fn(*dequeue_ops) - eval_update_ops = [] - for k, v in eval_metric_ops.items(): - eval_metric_ops[k] = (v[0], dummy_update_op) - eval_update_ops.append(v[1]) +class _OutfeedHostCallHook(session_run_hook.SessionRunHook): + """Hook to run host calls when use_tpu=False.""" - return eval_metric_ops, eval_update_ops + def __init__(self, tensors): + self._tensors = tensors + + def begin(self): + # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than + # create a separate hook to guarantee execution order, because summaries + # need to be initialized before the outfeed thread starts. + # TODO(jhseu): Make a wrapper hook instead? + self._init_ops = contrib_summary.summary_writer_initializer_op() + # Get all the writer resources from the initializer, so we know what to + # flush. + self._finalize_ops = [] + for op in self._init_ops: + self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) + + def after_create_session(self, session, coord): + session.run(self._init_ops) + + def before_run(self, run_context): + return basic_session_run_hooks.SessionRunArgs(self._tensors) + + def end(self, session): + session.run(self._finalize_ops) class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): @@ -1387,6 +1743,23 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): logging.info('examples/sec: %g', examples_per_sec) +class InstallSignalHandlerHook(session_run_hook.SessionRunHook): + """Change SIGINT (CTRL^C) handler to force quit the process. + + The default behavior often results in hanging processes. + The original handler is restored after training/evaluation. + """ + + def __init__(self): + self._signal_fn = signal.getsignal(signal.SIGINT) + + def before_run(self, run_context): + signal.signal(signal.SIGINT, signal.SIG_DFL) + + def end(self, session): + signal.signal(signal.SIGINT, self._signal_fn) + + class TPUEstimator(estimator_lib.Estimator): """Estimator with TPU support. @@ -1394,30 +1767,28 @@ class TPUEstimator(estimator_lib.Estimator): replicating inputs and models for each core, and returning to host periodically to run hooks. - If `use_tpu` is false, all training, evaluation, and predict are executed on - CPU. - - For training, TPUEstimator transforms a global batch size in params to a - per-shard batch size when calling the `input_fn` and `model_fn`. Users should - specify `train_batch_size` in constructor, and then get the batch size for - each shard in `input_fn` and `model_fn` by `params['batch_size']`. If - `TPUConfig.per_host_input_for_training` is `True`, `input_fn` is invoked per - host rather than per core. In this case, a global batch size is transformed a - per-host batch size in params for `input_fn`, but `model_fn` still gets - per-core batch size. - - For evaluation, if `eval_batch_size` is None, it is executed on CPU, even if - `use_tpu` is `True`. If `eval_batch_size` is not `None`, it is executed on - TPU, which is an experimental feature. In this case, `model_fn` should return - `TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the - `eval_metrics` for TPU evaluation. - + TPUEstimator transforms a global batch size in params to a per-shard batch + size when calling the `input_fn` and `model_fn`. Users should specify + global batch size in constructor, and then get the batch size for each shard + in `input_fn` and `model_fn` by `params['batch_size']`. + For training, `model_fn` gets per-core batch size; `input_fn` may get + per-core or per-host batch size depending on + `per_host_input_for_training` in `TPUConfig`. + For evaluation, `model_fn` gets per-core batch size and `input_fn` get + per-host batch size. + + `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` + for TPU evaluation. `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. + One can set `use_tpu` to `False` for testing. All training, evaluation, and + predict will be executed on CPU. `input_fn` and `model_fn` will receive + `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. + Current limitations: 1. TPU evaluation only works on single host. @@ -1472,6 +1843,7 @@ class TPUEstimator(estimator_lib.Estimator): use_tpu=True, train_batch_size=None, eval_batch_size=None, + predict_batch_size=None, batch_axis=None): """Constructs an `TPUEstimator` instance. @@ -1490,18 +1862,17 @@ class TPUEstimator(estimator_lib.Estimator): basic python types. There are reserved keys for `TPUEstimator`, including 'batch_size'. use_tpu: A bool indicating whether TPU support is enabled. Currently, - - TPU training respects this bit. - - If true, see `eval_batch_size` for evaluate support. + - TPU training and evaluation respect this bit. - Predict still happens on CPU. train_batch_size: An int representing the global training batch size. TPUEstimator transforms this global batch size to a per-shard batch size, as params['batch_size'], when calling `input_fn` and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be divisible by `config.tpu_config.num_shards`. - eval_batch_size: An int representing the global training batch size. - Currently, if `None`, evaluation is still executed on CPU (even when - `use_tpu` is True). In near future, `use_tpu` will be the only option to - switch between TPU/CPU evaluation. + eval_batch_size: An int representing evaluation batch size. + Must be divisible by `config.tpu_config.num_shards`. + predict_batch_size: An int representing the prediction batch size. + Must be divisible by `config.tpu_config.num_shards`. batch_axis: A python tuple of int values describing how each tensor produced by the Estimator `input_fn` should be split across the TPU compute shards. For example, if your input_fn produced (images, labels) @@ -1541,15 +1912,25 @@ class TPUEstimator(estimator_lib.Estimator): .format(train_batch_size, config.tpu_config.num_shards)) if eval_batch_size is not None: - if config.tpu_config.num_shards > 8: - raise NotImplementedError( - 'TPU evaluation is only supported with one host.') - + if not isinstance(eval_batch_size, int): + raise ValueError('`eval_batch_size` must be an int') + if eval_batch_size < 1: + raise ValueError('`eval_batch_size` must be positive') if eval_batch_size % config.tpu_config.num_shards != 0: raise ValueError( 'eval batch size {} must be divisible by number of shards {}' .format(eval_batch_size, config.tpu_config.num_shards)) + if predict_batch_size is not None: + if not isinstance(predict_batch_size, int): + raise ValueError('`predict_batch_size` must be an int') + if predict_batch_size < 1: + raise ValueError('`predict_batch_size` must be positive') + if predict_batch_size % config.tpu_config.num_shards != 0: + raise ValueError( + 'predict batch size {} must be divisible by number of shards {}' + .format(predict_batch_size, config.tpu_config.num_shards)) + # Verifies the model_fn signature according to Estimator framework. estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access # We cannot store config and params in this constructor as parent @@ -1569,7 +1950,7 @@ class TPUEstimator(estimator_lib.Estimator): # All properties passed to _TPUContext are immutable. self._ctx = _TPUContext(self._config, train_batch_size, eval_batch_size, - use_tpu) + predict_batch_size, use_tpu) def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1617,6 +1998,14 @@ class TPUEstimator(estimator_lib.Estimator): util_lib.check_positive_integer(steps, 'Eval steps') + if self._config.tpu_config.num_shards > 8: + raise NotImplementedError( + 'TPU evaluation is only supported with one host.') + + if self._ctx._eval_batch_size is None: # pylint: disable=protected-access + raise ValueError('`eval_batch_size` cannot be `None`' + 'if evaluate() is called on TPU.') + return [ evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access num_evals=steps), @@ -1657,7 +2046,9 @@ class TPUEstimator(estimator_lib.Estimator): if batch_size_for_input_fn is not None: kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn - if ctx.is_running_on_cpu(): + # For export_savedmodel, input_fn is never passed to Estimator. So, + # `is_export_mode` must be False. + if ctx.is_running_on_cpu(is_export_mode=False): with ops.device('/device:CPU:0'): return input_fn(**kwargs) @@ -1684,8 +2075,13 @@ class TPUEstimator(estimator_lib.Estimator): with self._ctx.with_mode(mode) as ctx: model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) - # TODO(jhseu): Move to PREDICT to TPU. - if ctx.is_running_on_cpu(): + # For export_savedmodel, input_fn is never passed to Estimator. So, + # if features is callable, it means it is the input_fn passed by + # TPUEstimator._call_input_fn. Then we can know if the mode == PREDICT, + # it implies, it is the .predict API, not export_savedmodel API. + is_export_mode = not callable(features) + + if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) return model_fn_wrapper.call_without_tpu(features, labels) @@ -1695,22 +2091,31 @@ class TPUEstimator(estimator_lib.Estimator): input_fn = features input_holders = _InputPipeline(input_fn, batch_axis, ctx) - enqueue_ops, dequeue_fn = ( + enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, scaffold = ( + loss, host_call, scaffold = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) + host_ops = host_call.create_tpu_hostcall() + if host_ops is None: + host_ops = [] hooks = [ - TPUInfeedOutfeedSessionHook(ctx, enqueue_ops), + TPUInfeedOutfeedSessionHook( + ctx, + enqueue_ops, + host_ops, + run_infeed_loop_on_coordinator=( + run_infeed_loop_on_coordinator)), ExamplesPerSecondHook(ctx.global_batch_size), + InstallSignalHandlerHook(), training.LoggingTensorHook( { 'loss': array_ops.identity(loss), 'step': training.get_global_step() }, every_n_secs=30) - ] + ] + input_hooks summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() @@ -1725,40 +2130,114 @@ class TPUEstimator(estimator_lib.Estimator): train_op=control_flow_ops.group(*update_ops), scaffold=scaffold) - # Now eval. - total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system( + if mode == model_fn_lib.ModeKeys.EVAL: + total_loss, host_calls, scaffold = _eval_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) + iterations_per_loop_var = _create_or_get_iterations_per_loop() + mean_loss = math_ops.div(total_loss, + math_ops.cast( + iterations_per_loop_var, + dtype=total_loss.dtype)) + + # Creates a dummy metric update_op for all metrics. Estimator expects + # all metrics in eval_metric_ops have update_op and calls them one by + # one. The real metric update_ops are invoked in a separated thread. + # So, here give Estimator the dummy op for all metrics. + with ops.control_dependencies([mean_loss]): + # After TPU evaluation computation is done (the mean_loss tensor), + # reads all variables back from TPU and updates the eval step + # counter properly + internal_ops_to_run = _sync_variables_ops() + internal_ops_to_run.append( + _increase_eval_step_op(iterations_per_loop_var)) + with ops.control_dependencies(internal_ops_to_run): + dummy_update_op = control_flow_ops.no_op() + + host_call_ret = host_calls.create_tpu_hostcall() + eval_metric_ops = {} + eval_update_ops = [] + for k, v in host_call_ret['eval_metrics'].items(): + eval_metric_ops[k] = (v[0], dummy_update_op) + eval_update_ops.append(v[1]) + + if 'host_call' not in host_call_ret: + host_ops = [] + else: + host_ops = host_call_ret['host_call'] + hooks = [ + TPUInfeedOutfeedSessionHook( + ctx, + enqueue_ops, + eval_update_ops + host_ops, + run_infeed_loop_on_coordinator=( + run_infeed_loop_on_coordinator)), + ] + input_hooks + + return model_fn_lib.EstimatorSpec( + mode, + loss=mean_loss, + evaluation_hooks=hooks, + eval_metric_ops=eval_metric_ops, + scaffold=scaffold) + + # Predict + assert mode == model_fn_lib.ModeKeys.PREDICT + + dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) - iterations_per_loop_var = _create_or_get_iterations_per_loop() - mean_loss = math_ops.div(total_loss, - math_ops.cast( - iterations_per_loop_var, - dtype=total_loss.dtype)) - - # Creates a dummy metric update_op for all metrics. Estimator expects - # all metrics in eval_metric_ops have update_op and calls them one by - # one. The real metric update_ops are invoked in a separated thread. So, - # here give Estimator the dummy op for all metrics. - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), - # reads all variables back from TPU and updates the eval step counter - # properly + with ops.control_dependencies([dummy_predict_op]): internal_ops_to_run = _sync_variables_ops() - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() + dummy_predict_op = control_flow_ops.no_op() + + # In train and evaluation, the main TPU program is passed to monitored + # training session to run. Infeed enqueue and outfeed dequeue are + # executed in side threads. This is not the configuration for + # prediction mode. + # + # For prediction, the Estimator executes the EstimatorSpec.predictions + # directly and yield the element (via generator) to call site. So, the + # outfeed based prediction must be passed to MonitoredSession directly. + # Other parts of the TPU execution are organized as follows. + # + # 1. All outfeed based Tensors must be grouped with predictions Tensors + # to form a single invocation. This avoid the issue we might trigger + # multiple outfeeds incorrectly. To achieve this, `host_call` is + # placed in control_dependencies of `stopping_signals`, and + # `stopping_signals` is passed into _StoppingPredictHook, which sets + # the `stopping_signals` as SessionRunArgs. MonitoredSession merges + # all SessionRunArgs with the fetch in session.run together. + # + # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue) + # are grouped together. They will be launched once and only once in + # side threads and they quit naturally according to the SAME stopping + # condition. + enqueue_ops.append(dummy_predict_op) + + host_call_ret = host_calls.create_tpu_hostcall() + if 'host_call' not in host_call_ret: + host_ops = [] + else: + host_ops = host_call_ret['host_call'] + + predictions = host_call_ret['predictions'] + stopping_signals = host_call_ret['signals'] + + with ops.control_dependencies(host_ops): + host_ops = [] # Empty, we do do not need it anymore. + scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( + stopping_signals) - eval_metric_ops, eval_update_ops = ( - eval_metric_ops.to_metric_metric_ops_for_tpu(dummy_update_op)) hooks = [ - TPUInfeedOutfeedSessionHook(ctx, enqueue_ops, eval_update_ops), - ] + _StoppingPredictHook(scalar_stopping_signal), + TPUInfeedOutfeedSessionHookForPrediction(ctx, enqueue_ops, + host_ops), + ] + input_hooks return model_fn_lib.EstimatorSpec( mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops, + prediction_hooks=hooks, + predictions=predictions, scaffold=scaffold) return _model_fn @@ -1769,7 +2248,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): num_cores = ctx.num_cores iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_eval_step, eval_metric_ops, captured_scaffold_fn = ( + single_tpu_eval_step, host_calls, captured_scaffold_fn = ( model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) def multi_tpu_eval_steps_on_single_shard(): @@ -1785,7 +2264,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): outputs_from_all_shards=False) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, eval_metric_ops, scaffold + return loss, host_calls, scaffold def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -1793,7 +2272,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): num_cores = ctx.num_cores iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_train_step, captured_scaffold_fn = ( + single_tpu_train_step, host_call, captured_scaffold_fn = ( model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) def multi_tpu_train_steps_on_single_shard(): @@ -1809,7 +2288,35 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): outputs_from_all_shards=False) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, scaffold + return loss, host_call, scaffold + + +def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): + """Executes `model_fn_wrapper` multiple times on all TPU shards.""" + num_cores = ctx.num_cores + + single_tpu_predict_step, host_calls, captured_scaffold_fn = ( + model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)) + + def multi_tpu_predict_steps_on_single_shard(): + + def cond(scalar_stopping_signal): + return math_ops.logical_not( + _StopSignals.should_stop(scalar_stopping_signal)) + + inputs = [_StopSignals.NON_STOPPING_SIGNAL] + outputs = training_loop.while_loop( + cond, single_tpu_predict_step, inputs=inputs, name=b'loop') + return outputs + + (dummy_predict_op,) = tpu.shard( + multi_tpu_predict_steps_on_single_shard, + inputs=[], + num_shards=num_cores, + outputs_from_all_shards=False) + + scaffold = _get_scaffold(captured_scaffold_fn) + return dummy_predict_op, host_calls, scaffold def _wrap_computation_in_while_loop(device, op_fn): @@ -1830,6 +2337,29 @@ def _wrap_computation_in_while_loop(device, op_fn): parallel_iterations=1) +def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn): + """Wraps the ops generated by `op_fn` in tf.while_loop.""" + + def cond(scalar_stopping_signal): + return math_ops.logical_not( + _StopSignals.should_stop(scalar_stopping_signal)) + + def computation(unused_scalar_stopping_signal): + return_value = op_fn() + execute_ops = return_value['ops'] + signals = return_value['signals'] + with ops.control_dependencies(execute_ops): + return _StopSignals.as_scalar_stopping_signal(signals) + + # By setting parallel_iterations=1, the parallel execution in while_loop is + # basically turned off. + with ops.device(device): + return control_flow_ops.while_loop( + cond, + computation, [_StopSignals.NON_STOPPING_SIGNAL], + parallel_iterations=1) + + def _validate_tpu_training_graph(): """Validate graph before running distributed training. @@ -1920,3 +2450,194 @@ class _CapturingContext(control_flow_ops.ControlFlowContext): def __exit__(self, _, __, ___): # pylint: disable=invalid-name self._g._set_control_flow_context(self._old) # pylint: disable=protected-access + + +class _Inputs(object): + """A data structure representing the input_fn returned values. + + This also supports the returned value from input_fn as `Dataset`. + """ + + def __init__(self, features=None, labels=None, dataset=None, signals=None): + if dataset is not None and (features is not None or labels is not None or + signals is not None): + raise RuntimeError('Internal Error: Either (features and labels) or ' + 'dataset should be provided, not both. Please file ' + 'bug') + + self._features = features + self._labels = labels + self._signals = signals + + self._dataset = dataset + self._iterator = None + + @staticmethod + def from_input_fn(return_values): + """Returns an `_Inputs` instance according to `input_fn` return value.""" + if isinstance(return_values, dataset_ops.Dataset): + dataset = return_values + return _Inputs(dataset=dataset) + + features, labels = _Inputs._parse_inputs(return_values) + return _Inputs(features, labels) + + @staticmethod + def _parse_inputs(return_values): + if isinstance(return_values, tuple): + features, labels = return_values + else: + features, labels = return_values, None + return features, labels + + @property + def is_dataset(self): + """Returns True if the return value from input_fn is Dataset.""" + return self._dataset is not None + + def dataset_initializer_hook(self): + """Returns a `SessionRunHook` to initialize this dataset. + + This must be called before `features_and_labels`. + """ + iterator = self._dataset.make_initializable_iterator() + # pylint: disable=protected-access + hook = estimator_lib._DatasetInitializerHook(iterator) + self._iterator = iterator + return hook + + def features_and_labels(self): + """Gets `features` and `labels`.""" + if self.is_dataset: + return _Inputs._parse_inputs(self._iterator.get_next()) + + return (self._features, self._labels) + + def signals(self): + return self._signals + + @property + def dataset(self): + return self._dataset + + +# TODO(xiejw): Extend this to support final partial batch. +class _InputsWithStoppingSignals(_Inputs): + """Inputs with `_StopSignals` inserted into the dataset.""" + + def __init__(self, dataset, batch_size): + + assert dataset is not None + + user_provided_dataset = dataset.map( + _InputsWithStoppingSignals.insert_stopping_signal( + stop=False, batch_size=batch_size)) + final_batch_dataset = dataset.take(1).map( + _InputsWithStoppingSignals.insert_stopping_signal( + stop=True, batch_size=batch_size)) + dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) + + super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) + self._current_inputs = None + + def features_and_labels(self): + if self._current_inputs is not None: + raise RuntimeError( + 'Internal Error: The previous inputs have not been properly ' + 'consumed. First call features_and_labels, then call signals.') + + inputs_with_signals = self._iterator.get_next() + features = inputs_with_signals['features'] + labels = inputs_with_signals.get('labels') + + self._current_inputs = inputs_with_signals + return features, labels + + def signals(self): + """Returns the `Signals` from `_Inputs`.""" + if self._current_inputs is None: + raise RuntimeError( + 'Internal Error: The current inputs have not been properly ' + 'generated. First call features_and_labels, then call signals.') + signals = self._current_inputs['signals'] + self._current_inputs = None + return signals + + @staticmethod + def insert_stopping_signal(stop, batch_size): + """Inserts stopping_signal into dataset via _map_fn. + + Here we change the data structure in the dataset, such that the return value + is a dictionary now and `features`, `labels`, and `signals` are three + distinguished keys in that dict. This provides a better structure, which + eases the process to decompose the inputs (see `features_and_labels`). + + Args: + stop: bool, state of current stopping signals. + batch_size: int, batch size. + + Returns: + A map_fn passed to dataset.map API. + """ + + def _map_fn(*args): + features, labels = _Inputs._parse_inputs(args) + new_input_dict = {} + new_input_dict['features'] = features + if labels is not None: + new_input_dict['labels'] = labels + new_input_dict['signals'] = _StopSignals( + stop=stop, batch_size=batch_size).as_dict() + return new_input_dict + + return _map_fn + + +class _StopSignals(object): + """Signals class holding all logic to handle TPU stopping condition.""" + + NON_STOPPING_SIGNAL = 0.0 + STOPPING_SIGNAL = 1.0 + + def __init__(self, stop, batch_size): + self._stop = stop + self._batch_size = batch_size + + def as_dict(self): + shape = [self._batch_size, 1] + dtype = dtypes.float32 + + if self._stop: + stopping = array_ops.ones(shape=shape, dtype=dtype) + else: + stopping = array_ops.zeros(shape=shape, dtype=dtype) + + return {'stopping': stopping} + + @staticmethod + def as_scalar_stopping_signal(signals): + return array_ops.identity(signals['stopping'][0][0]) + + @staticmethod + def should_stop(scalar_stopping_signal): + return scalar_stopping_signal >= _StopSignals.STOPPING_SIGNAL + + +class _SignalsHelper(object): + """A general helper class to handle common signals manipulation.""" + + def __init__(self, signals): + self._signal_keys = [] + for key in sorted(signals.iterkeys()): + self._signal_keys.append(key) + + @property + def num_signals(self): + return len(self._signal_keys) + + def unflatten(self, tensor_list): + return dict(zip(self._signal_keys, tensor_list)) + + @staticmethod + def as_tensor_list(signals): + return [signals[key] for key in sorted(signals.iterkeys())] diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index cccaa2b833ee764921508a5b6d6affe0b8822ede..6db373d2d5e20ea7da449530b2730403c3bb64cc 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,6 +26,7 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", + "python/training/tensor_queue_dataset.py", "python/training/training.py", "python/training/tuner.py", ], @@ -285,6 +286,28 @@ py_test( ], ) +py_test( + name = "tensor_queue_dataset_test", + size = "large", + srcs = ["python/training/tensor_queue_dataset_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":training_py", + "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_seed", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..409aba817c1ec37003eb98f000f6cf8918234c5d --- /dev/null +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -0,0 +1,200 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for Datasets and Iterators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import nest as tf_nest + + +class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): + """A `Dataset` that prepends a queue to another `Dataset`. + + A vector of handles to the queue is returned as the first component of + the associated iterator. This vector can be passed to + `enqueue_in_queue_dataset` to add new elements to the queue. + """ + + def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): + """Initialize `PrependFromQueueAndPaddedBatchDataset`.""" + super(_PrependFromQueueAndPaddedBatchDataset, self).__init__() + if sparse.any_sparse(input_dataset.output_classes): + raise TypeError( + "Batching of padded sparse tensors is not currently supported") + self._input_dataset = input_dataset + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + # pylint: disable=protected-access + if padded_shapes is None: + self._padded_shapes = nest.map_structure( + dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + else: + self._padded_shapes = nest.map_structure_up_to( + input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + padded_shapes) + padding_values = ( + padding_values if padding_values is not None else + dataset_ops._default_padding(input_dataset)) + self._padding_values = nest.map_structure_up_to( + input_dataset.output_shapes, dataset_ops._padding_value_to_tensor, + padding_values, input_dataset.output_types) + # pylint: enable=protected-access + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset( + self._input_dataset._as_variant_tensor(), + batch_size=self._batch_size, + padded_shapes=[ + ops.convert_to_tensor(s, dtype=dtypes.int64) + for s in nest.flatten(self._padded_shapes) + ], + padding_values=nest.flatten(self._padding_values), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + # pylint: enable=protected-access + + @property + def output_classes(self): + return (ops.Tensor, self._input_dataset.output_classes) + + def _as_batch_shape(self, shape_like): + return tensor_shape.vector(None).concatenate( + tensor_util.constant_value_as_shape(shape_like)) + + @property + def output_shapes(self): + # First output is a variant representing the Queue + return (tensor_shape.vector(None), + nest.map_structure(self._as_batch_shape, self._padded_shapes)) + + @property + def output_types(self): + # First output is a variant representing the Queue + return (dtypes.variant, self._input_dataset.output_types) + + +def prepend_from_queue_and_padded_batch_dataset(batch_size, + padding_values=None, + padded_shapes=None): + """A transformation that prepends a queue to a `Dataset` and batches results. + + A vector of handles to the queue is returned as the first component of the + associated iterator. This vector can be passed to `enqueue_in_queue_dataset` + to add new elements to the queue. + + Below is an example of how this dataset might be used to split incoming + variable-length sequences into "head" and "rest" parts, where "rest" parts + are re-enqueued back into the dataset. A more realistic example would + perform some calculation on the "head" and modify some components of "rest" + with the result (before re-enqueueing). + + ```python + dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)]) + # Make a dataset of variable-length vectors and their lengths. + dataset = dataset.map(lambda count: (count, tf.ones((count,)))) + # Emit a queue we can prepend to, and counts/values as padded batch. + dataset = dataset.apply( + tf.contrib.training.prepend_from_queue_and_padded_batch_dataset( + batch_size=10)) + dataset = dataset.prefetch(1) + + iterator = dataset.make_one_shot_iterator() + queue, (count, padded_value) = iterator.get_next() + + # Split the padded_value into two pieces: head and rest + rest_indices = tf.squeeze(tf.where(count > 3), axis=1) + bound = tf.minimum(3, tf.reduce_max(count)) + value_head = padded_value[:, :bound] + count_rest = tf.gather(count - 3, rest_indices) + value_rest = tf.gather(padded_value[:, bound:], rest_indices) + queue_rest = tf.gather(queue, rest_indices) + enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset( + queue_rest, (count_rest, value_rest)) + with tf.control_dependencies([enqueue_rest_op]): + calculation = fn(value_head) + + while True: # Will raise OutOfRange when finished with all pieces. + session.run(calculation) + ``` + + Args: + batch_size: `int64` scalar tensor. The batch size to use when performing + padded batching. + padding_values: (optional) Nested tuple of scalar tensors. If provided, + the structure and dtypes of padding_values should match that of + incoming dataset's `output_types`. + padded_shapes: (optional) Nested tuple of `int64` vector tensors. + If provided, the structure must match that of the incoming dataset's + `output_types`. If not provided, the incoming dataset's `output_shapes` + is used. Any unknown (`None` or `-1`) dimensions in the shapes are + treated as being unique per-batch: for each batch time, an unknown + dimension is replaced with the maximum given value of this dimension + across all tensors for the given component in the batch. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _PrependFromQueueAndPaddedBatchDataset( + dataset, + batch_size=batch_size, + padding_values=padding_values, + padded_shapes=padded_shapes) + + return _apply_fn + + +def enqueue_in_queue_dataset(queue, components): + """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`. + + The components' dtypes and shapes must be compatible with the `output_shapes` + attribute of the `dataset` created by + `prepend_from_queue_and_padded_batch_dataset`. This operation supports both + non-batched and batched modes. + + For more details, see the example in the docstring for + `prepend_from_queue_and_padded_batch_dataset`. + + Args: + queue: `variant` scalar or vector tensor. + The tensor emitted by the first component of the iterator associated with + `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar, + then the `components` input tensors should not have a prepended batch + dimension. + components: Nested tuple of tensors, each with a leading batch dimension + if `queue` is a vector. The structure, dtypes, and shapes + (excluding batch dimension) must match the nested tuples + `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue + output types and shapes) of the `dataset` emitted by + the original `prepend_from_queue_and_padded_batch_dataset` call. + + Returns: + An `Operation` that enqueues `components` into the dataset(s) associated + with entries of `queue`. + """ + return gen_dataset_ops.enqueue_in_queue_dataset( + queue=queue, components=tf_nest.flatten(components)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0338f409a203c232e63e99534a8f6d6a43fa661e --- /dev/null +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py @@ -0,0 +1,355 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TensorQueueDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): + + def testNoEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types) + self.assertAllEqual(([None],) * 2, + [x.as_list() for x in dataset.output_shapes]) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + self.assertEqual([0], self.evaluate(value)) + self.assertEqual([1], self.evaluate(value)) + self.assertEqual([2], self.evaluate(value)) + with self.assertRaisesOpError("End of sequence"): + self.evaluate(value) + + def testBatchedNoEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + self.assertAllEqual([0, 1], self.evaluate(value)) + self.assertAllEqual([2], self.evaluate(value)) + with self.assertRaisesOpError("End of sequence"): + self.evaluate(value) + + def testBatchedWithBiggerPaddingNoEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=2, padded_shapes=[3])) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value)) + self.assertAllEqual([[2, 0, 0]], self.evaluate(value)) + with self.assertRaisesOpError("End of sequence"): + self.evaluate(value) + + def testBatchedWithBiggerPaddingOneEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=1, padded_shapes=[3])) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) + with self.test_session() as sess: + self.assertAllEqual([[0, 0, 0]], sess.run(value)) + value_1, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([[1, 0, 0]], value_1) + value_2, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([[-1, 0, 0]], value_2) + value_3 = sess.run(value) + self.assertAllEqual([[1, 0, 0]], value_3) + value_4, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([[2, 0, 0]], value_4) + value_5 = sess.run(value) + self.assertAllEqual([[-2, 0, 0]], value_5) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testOneEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) + with self.test_session() as sess: + self.assertEqual([0], sess.run(value)) + value_1, _ = sess.run([value, enqueue_negative]) + self.assertEqual([1], value_1) + value_2, _ = sess.run([value, enqueue_negative]) + self.assertEqual([-1], value_2) + value_3 = sess.run(value) + self.assertEqual([1], value_3) + value_4, _ = sess.run([value, enqueue_negative]) + self.assertEqual([2], value_4) + value_5 = sess.run(value) + self.assertEqual([-2], value_5) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testBatchedOneEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) + enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]], + array_ops.expand_dims( + value[0], axis=0)) + with self.test_session() as sess: + value_0, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([0, 1], value_0) + value_1, _ = sess.run([value, enqueue_zeroth]) + self.assertAllEqual([0, -1], value_1) + value_2, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([0, 2], value_2) + self.assertAllEqual([0, -2], sess.run(value)) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testManyEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_many_more = [ + tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i) + for i in range(1000) + ] + with self.test_session() as sess: + value_0, _ = sess.run((value, enqueue_many_more)) + self.assertEqual([0], value_0) + rest = [] + for _ in range(1000): + rest.append(sess.run(value)) + self.assertEquals([[100 + i] for i in range(1000)], sorted(rest)) + # Going back to the original input. + value_1, _ = sess.run((value, enqueue_many_more)) + self.assertEqual(1, value_1) + rest = [] + for _ in range(1000): + rest.append(sess.run(value)) + self.assertEquals([[100 + i + 1] for i in range(1000)], sorted(rest)) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testEnqueueWithPrefetch(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + # Prefetching will request additional values before they are + # available to the queue. + dataset = dataset.prefetch(buffer_size=3) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1) + with self.test_session() as sess: + i = 0 + while i < 4: + received, _ = sess.run((value, enqueue)) + if received.size > 0: + self.assertAllEqual([i], received) + i += 1 + received_last = False + while True: + try: + received = sess.run(value) + if received.size > 0: + self.assertAllEqual([4], received) + received_last = True + except errors.OutOfRangeError: + break + self.assertTrue(received_last) + + def testDatasetWithPaddedShapeSmallerThanInputFails(self): + dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=1, padded_shapes=[2])) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + with self.test_session() as sess: + with self.assertRaisesOpError( + r"Incompatible input shapes at component 0 between " + r"input dataset this dataset: \[3\] vs. \[2\]"): + sess.run(value) + + def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + + enqueue_bad_structure = tqd.enqueue_in_queue_dataset( + queue_handle, (value, value)) + enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle, + np.array( + [1.0], + dtype=np.float32)) + enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset( + queue_handle, ([1],)) + enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle, + np.array( + [[1]], dtype=np.int32)) + + with self.test_session() as sess: + with self.assertRaisesOpError( + "mismatched number of tensors. Queue expects 1 tensors but " + "tried to insert 2"): + sess.run(enqueue_bad_structure) + with self.assertRaisesOpError(r"Expected component 0 to have batched " + r"shape \[1,...\], but saw shape: \[\]"): + sess.run(enqueue_bad_shape_no_batch_dim) + with self.assertRaisesOpError( + r"mismatched shapes at component 0. Attempted to insert tensor " + r"with shape \[1\] but queue expected shape: \[\]"): + sess.run(enqueue_bad_shape) + with self.assertRaisesOpError( + r"mismatched dtypes at component 0. Attempted to insert tensor " + r"of type float but queue expected type: int32"): + sess.run(enqueue_bad_dtype) + + def testEnqueueWithPaddedBatchFailsWithInformativeError(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + with self.assertRaisesRegexp( + TypeError, r"Unable to create padding for field of type 'variant'"): + dataset.padded_batch(batch_size=10, padded_shapes=[1]) + + def testOneEnqueueWithPadding(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) + # Make a dataset of variable-length vectors and their lengths. + dataset = dataset.map( + lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) + # Emit a queue we can prepend to, and counts/values as padded + # batch. + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3)) + + iterator = dataset.make_one_shot_iterator() + queue, (count, padded_value) = iterator.get_next() + + # Split the padded_value into two pieces: head and rest + rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1) + bound = math_ops.minimum(2, math_ops.reduce_max(count)) + value_head = padded_value[:, :bound] + count_rest = array_ops.gather(count - 2, rest_indices) + value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:] + queue_rest = array_ops.gather(queue, rest_indices) + enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest, + (count_rest, value_rest)) + with ops.control_dependencies([enqueue_rest_op]): + calc = array_ops.identity(value_head) + + with self.test_session() as sess: + self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc)) + self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc)) + self.assertAllEqual([[6, 6]], sess.run(calc)) + self.assertAllEqual([[6, 6]], sess.run(calc)) + # Get some final batches due to prefetching. + for _ in range(3): + try: + self.assertAllEqual( + np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc)) + except errors.OutOfRangeError as e: + self.assertTrue(str(e).startswith("End of sequence")) + + def testNonstandardPadding(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) + # Make a dataset of variable-length vectors and their lengths. + dataset = dataset.map( + lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) + # Emit a queue we can prepend to, and counts/values as padded + # batch. + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=3, padding_values=( + 0, + -1, + ))) + + iterator = dataset.make_one_shot_iterator() + _, (unused_count, padded_value) = iterator.get_next() + + with self.test_session() as sess: + self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]], + sess.run(padded_value)) + self.assertAllEqual([[6] * 6], sess.run(padded_value)) + with self.assertRaisesOpError("End of sequence"): + sess.run(padded_value) + + +# TODO(ebrevdo): Figure out how to use run_core_tests to test state +# saving of an iterator that's had some tensors enqueued into its queue. +class PrependFromQueueAndPaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPrependFromQueueAndPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4)) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, ""))) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc index 2992a61ea8186caada394208e9c27ddffe896dd1..9675428e56e93c9669753371dbca47d56325b0c4 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc @@ -142,9 +142,9 @@ Status ConvertConstantsToImmutable(const string& in_graph_filename, const auto load_graph_status = ReadBinaryProto(default_env, in_graph_filename, &graph_def); if (!load_graph_status.ok()) { - return tensorflow::errors::NotFound("Failed to load graph at '", - in_graph_filename, "' : ", - load_graph_status.error_message()); + return tensorflow::errors::NotFound( + "Failed to load graph at '", in_graph_filename, + "' : ", load_graph_status.error_message()); } NodeConverter node_converter; diff --git a/tensorflow/contrib/util/inspect_checkpoint.cc b/tensorflow/contrib/util/inspect_checkpoint.cc index 39088aeaad68e26344b2e89ce10ae6da8026e481..9b578ceb07548b8d198f64bc859d31c92774a286 100644 --- a/tensorflow/contrib/util/inspect_checkpoint.cc +++ b/tensorflow/contrib/util/inspect_checkpoint.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/tensor_slice_reader.h" namespace tensorflow { diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 47ed83f521c5e6165c906ea557e74faf27df2112..1a0b5028febb7b11f979abd179a3227a2615252d 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -49,8 +49,8 @@ VerbsServer::~VerbsServer() { Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, GrpcChannelCache** channel_cache) { string name_prefix = - strings::StrCat("/job:", server_def.job_name(), "/replica:0", "/task:", - server_def.task_index()); + strings::StrCat("/job:", server_def.job_name(), "/replica:0", + "/task:", server_def.task_index()); GrpcChannelSpec channel_spec; TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 90c2823ea4434b53567d8efa5628ac9db13ed275..4ba84b420f23849fd1e0bf8c13c0003175993666 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -433,6 +433,7 @@ tf_cuda_library( "framework/common_shape_fns.h", "framework/control_flow.h", # TODO(josh11b): Make internal? "framework/dataset.h", + "framework/dataset_stateful_op_whitelist.h", "framework/device_base.h", "framework/function.h", "framework/graph_def_util.h", @@ -1355,6 +1356,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "protobuf/device_properties_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/device_properties.proto", + visibility = ["//visibility:public"], +) + # ----------------------------------------------------------------------------- # Internal targets @@ -1902,6 +1910,7 @@ cc_library( tf_cuda_library( name = "cuda_device_functions", hdrs = ["util/cuda_device_functions.h"], + cuda_deps = ["//third_party_gpus/cuda:cuda_headers"], visibility = ["//visibility:public"], deps = [":framework_lite"], ) diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 81187ff6b772633105e0962d9da8f87d6cfd9558..58dbac4e8edac7079d315fbfcdafbd136793df0b 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -96,6 +96,7 @@ tf_cc_test( srcs = ["api_test.cc"], data = [ ":base_api_def", + ":python_api_def", ], deps = [ ":excluded_ops_lib", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 112c55ccc3ba1262b48c1b6c0890b3ae22744383..477a0b670e49f8aa4ee8c250d4957886eb865ed5 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -41,8 +41,9 @@ namespace tensorflow { namespace { constexpr char kDefaultApiDefDir[] = "tensorflow/core/api_def/base_api"; +constexpr char kPythonApiDefDir[] = + "tensorflow/core/api_def/python_api"; constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; -} // namespace // Reads golden ApiDef files and returns a map from file name to ApiDef file // contents. @@ -66,9 +67,93 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir, } } -class ApiTest : public ::testing::Test { +void TestAllApiDefsHaveCorrespondingOp( + const OpList& ops, const std::unordered_map& api_defs_map) { + std::unordered_set op_names; + for (const auto& op : ops.op()) { + op_names.insert(op.name()); + } + for (const auto& name_and_api_def : api_defs_map) { + ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end()) + << name_and_api_def.first << " op has ApiDef but missing from ops. " + << "Does api_def_" << name_and_api_def.first << " need to be deleted?"; + } +} + +void TestAllApiDefInputArgsAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_arg : api_def.in_arg()) { + bool found_arg = false; + for (const auto& op_arg : op.input_arg()) { + if (api_def_arg.name() == op_arg.name()) { + found_arg = true; + break; + } + } + ASSERT_TRUE(found_arg) + << "Input argument " << api_def_arg.name() + << " (overwritten in api_def_" << op.name() + << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} + +void TestAllApiDefOutputArgsAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_arg : api_def.out_arg()) { + bool found_arg = false; + for (const auto& op_arg : op.output_arg()) { + if (api_def_arg.name() == op_arg.name()) { + found_arg = true; + break; + } + } + ASSERT_TRUE(found_arg) + << "Output argument " << api_def_arg.name() + << " (overwritten in api_def_" << op.name() + << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} + +void TestAllApiDefAttributeNamesAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_attr : api_def.attr()) { + bool found_attr = false; + for (const auto& op_attr : op.attr()) { + if (api_def_attr.name() == op_attr.name()) { + found_attr = true; + } + } + ASSERT_TRUE(found_attr) + << "Attribute " << api_def_attr.name() << " (overwritten in api_def_" + << op.name() << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} +} // namespace + +class BaseApiTest : public ::testing::Test { protected: - ApiTest() { + BaseApiTest() { OpRegistry::Global()->Export(false, &ops_); const std::vector multi_line_fields = {"description"}; @@ -80,7 +165,7 @@ class ApiTest : public ::testing::Test { }; // Check that all ops have an ApiDef. -TEST_F(ApiTest, AllOpsAreInApiDef) { +TEST_F(BaseApiTest, AllOpsAreInApiDef) { auto* excluded_ops = GetExcludedOps(); for (const auto& op : ops_.op()) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { @@ -94,16 +179,8 @@ TEST_F(ApiTest, AllOpsAreInApiDef) { } // Check that ApiDefs have a corresponding op. -TEST_F(ApiTest, AllApiDefsHaveCorrespondingOp) { - std::unordered_set op_names; - for (const auto& op : ops_.op()) { - op_names.insert(op.name()); - } - for (const auto& name_and_api_def : api_defs_map_) { - ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end()) - << name_and_api_def.first << " op has ApiDef but missing from ops. " - << "Does api_def_" << name_and_api_def.first << " need to be deleted?"; - } +TEST_F(BaseApiTest, AllApiDefsHaveCorrespondingOp) { + TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_); } string GetOpDefHasDocStringError(const string& op_name) { @@ -117,7 +194,7 @@ string GetOpDefHasDocStringError(const string& op_name) { // Check that OpDef's do not have descriptions and summaries. // Descriptions and summaries must be in corresponding ApiDefs. -TEST_F(ApiTest, OpDefsShouldNotHaveDocs) { +TEST_F(BaseApiTest, OpDefsShouldNotHaveDocs) { auto* excluded_ops = GetExcludedOps(); for (const auto& op : ops_.op()) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { @@ -143,62 +220,56 @@ TEST_F(ApiTest, OpDefsShouldNotHaveDocs) { // Checks that input arg names in an ApiDef match input // arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefInputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.in_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.input_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Input argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); - } - } +TEST_F(BaseApiTest, AllApiDefInputArgsAreValid) { + TestAllApiDefInputArgsAreValid(ops_, api_defs_map_); } // Checks that output arg names in an ApiDef match output // arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefOutputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.out_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.output_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Output argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); - } - } +TEST_F(BaseApiTest, AllApiDefOutputArgsAreValid) { + TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_); } // Checks that attribute names in an ApiDef match attribute // names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefAttributeNamesAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_attr : api_def.attr()) { - bool found_attr = false; - for (const auto& op_attr : op.attr()) { - if (api_def_attr.name() == op_attr.name()) { - found_attr = true; - } - } - ASSERT_TRUE(found_attr) - << "Attribute " << api_def_attr.name() << " (overwritten in api_def_" - << op.name() << ".pbtxt) is not defined in OpDef for " << op.name(); - } +TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) { + TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); +} + +class PythonApiTest : public ::testing::Test { + protected: + PythonApiTest() { + OpRegistry::Global()->Export(false, &ops_); + const std::vector multi_line_fields = {"description"}; + + Env* env = Env::Default(); + GetGoldenApiDefs(env, kPythonApiDefDir, &api_defs_map_); } + OpList ops_; + std::unordered_map api_defs_map_; +}; + +// Check that ApiDefs have a corresponding op. +TEST_F(PythonApiTest, AllApiDefsHaveCorrespondingOp) { + TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_); } + +// Checks that input arg names in an ApiDef match input +// arg names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefInputArgsAreValid) { + TestAllApiDefInputArgsAreValid(ops_, api_defs_map_); +} + +// Checks that output arg names in an ApiDef match output +// arg names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefOutputArgsAreValid) { + TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_); +} + +// Checks that attribute names in an ApiDef match attribute +// names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) { + TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); +} + } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt index 5d21d7bab699ff481c65ed44eb9bf66ec14ea387..ac05b54eea95f70e4a6db843aab13adf7b94602c 100644 --- a/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt @@ -20,10 +20,7 @@ END } summary: "Adds a value to the current value of a variable." description: <::min()`. +If the maximum is empty for a given segment ID `i`, it outputs the smallest +possible value for the specific numeric type, +`output[i] = numeric_limits::lowest()`.
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..55ea69b5dd5f7fda5c877ca5771ec2cbb86e3a9a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt @@ -0,0 +1,33 @@ +op { + graph_op_name: "UnsortedSegmentMin" + in_arg { + name: "segment_ids" + description: <::max()`. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..577ff53d60c5a174b4ba43a667885a6983b2dfb9 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt @@ -0,0 +1,32 @@ +op { + graph_op_name: "UnsortedSegmentProd" + in_arg { + name: "segment_ids" + description: <ChunkFromHandle(prev); strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false)); diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 0398c2a60d1fe4dfeed91e242272f13dd45389b2..b5a51d2526d95313d4564337ae0420472bc0b3da 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -328,7 +328,8 @@ void FindConstantFoldableNodes( ConsiderConstantFoldableNode( n, opts, nodes, constant_control_deps, shape_replacement_map, &internal_node_inserted); - }); + }, + NodeComparatorName()); // If we have inserted just leaf level nodes, then there is nothing to fold. if (!internal_node_inserted) { nodes->clear(); @@ -339,8 +340,8 @@ void FindConstantFoldableNodes( typedef std::pair NodeAndOutput; int64 UniqueConstantId() { - static std::atomic_int_fast64_t id; - return id.fetch_add(1); + static std::atomic_int_fast64_t unique_constant_id; + return unique_constant_id.fetch_add(1); } // Adds n to constant_graph which is being built up for subsequent evaluation of @@ -386,14 +387,12 @@ void AddShapeNodeToConstantGraph( const std::unordered_map>& shape_replacement_map, std::unordered_map>* node_map, - Graph* constant_graph) { + const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) { std::vector& added = (*node_map)[n]; const string& node_name = n->name(); for (const Tensor& t : shape_replacement_map.at(n)) { auto builder = - NodeDefBuilder(strings::StrCat(constant_graph->NewName(node_name), - "__cf__", UniqueConstantId()), - "Const") + NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const") .Attr("dtype", t.dtype()) .Attr("value", t); NodeDef def; @@ -414,7 +413,8 @@ Graph* GetConstantGraph( const Graph* orig_graph, const std::vector& nodes, const std::unordered_map>& shape_replacement_map, - std::map* tensors_to_fetch) { + std::map* tensors_to_fetch, + const ConstantFoldNameGenerator& generate_new_name) { Graph* constant_graph = new Graph(orig_graph->op_registry()); std::unordered_map> node_map; node_map[orig_graph->source_node()] = {constant_graph->source_node()}; @@ -424,7 +424,7 @@ Graph* GetConstantGraph( AddNodeToConstantGraph(n, &node_map, constant_graph); } else { AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map, - constant_graph); + generate_new_name, constant_graph); } } @@ -458,10 +458,11 @@ Graph* GetConstantGraph( // replacement was successful, false otherwise. // 'control_deps' is the set of nodes that should be control predecessors of the // new constant node. -bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, - NodeAndOutput tensor, const Tensor& constant, - const gtl::FlatSet& control_deps, - int64 max_constant_size_in_bytes) { +bool ReplaceTensorWithConstant( + Graph* graph, Device* partition_device, NodeAndOutput tensor, + const Tensor& constant, const gtl::FlatSet& control_deps, + int64 max_constant_size_in_bytes, + const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY @@ -509,9 +510,7 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, } const string& node_name = n->name(); Node* constant_node; - auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name), - "__cf__", UniqueConstantId()), - "Const") + auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const") .Attr("dtype", constant.dtype()) .Attr("value", constant); if (partition_device) { @@ -555,6 +554,13 @@ Status ConstantFold(const ConstantFoldingOptions& opts, FunctionLibraryRuntime* function_library, Env* env, Device* partition_device, Graph* graph, bool* was_mutated) { DumpGraph("Before", graph); + ConstantFoldNameGenerator generate_new_name = opts.generate_new_name; + if (generate_new_name == nullptr) { + generate_new_name = [](Graph* graph, string old_name) { + return strings::StrCat(graph->NewName(old_name), "__cf__", + UniqueConstantId()); + }; + } std::vector constant_foldable_nodes; std::unordered_map> constant_control_deps; @@ -571,7 +577,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, std::map tensors_to_fetch; std::unique_ptr constant_graph( GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map, - &tensors_to_fetch)); + &tensors_to_fetch, generate_new_name)); DumpGraph("Constant graph", constant_graph.get()); if (tensors_to_fetch.empty()) { @@ -585,7 +591,16 @@ Status ConstantFold(const ConstantFoldingOptions& opts, std::vector tensors_to_fetch_names; std::vector tensors_to_replace; - for (auto n : tensors_to_fetch) { + // Sorting the nodes based on the name gives us a stable ordering between runs + // for the same graph. + std::vector> tensors_to_fetch_sorted( + tensors_to_fetch.begin(), tensors_to_fetch.end()); + std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(), + [](const std::pair& n1, + const std::pair& n2) { + return n1.first.first->name() < n2.first.first->name(); + }); + for (auto n : tensors_to_fetch_sorted) { tensors_to_fetch_names.push_back( strings::StrCat(n.first.first->name(), ":", n.first.second)); tensors_to_replace.push_back({n.second, n.first.second}); @@ -617,7 +632,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, constant_control_deps[tensors_to_replace[c].first]; if (ReplaceTensorWithConstant( graph, partition_device, tensors_to_replace[c], outputs[c], - control_deps, opts.max_constant_size_in_bytes)) { + control_deps, opts.max_constant_size_in_bytes, generate_new_name)) { ++num_nodes_replaced; } } diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h index e4d724c58a25347db3e40a0d024acf1ac97ea575..b1e1fb831963bccb81731752ec76b9d5be123d9f 100644 --- a/tensorflow/core/common_runtime/constant_folding.h +++ b/tensorflow/core/common_runtime/constant_folding.h @@ -24,6 +24,11 @@ limitations under the License. namespace tensorflow { +// This generator type is used to generate a name for the newly folded node +// based on the node's old name. +using ConstantFoldNameGenerator = + std::function; + // Options specific to constant folding optimizations. struct ConstantFoldingOptions { // If "consider" is not a nullptr, then only constant fold a node "n" if @@ -37,6 +42,11 @@ struct ConstantFoldingOptions { // The maximum size of each constant created during constant folding // optimization. int64 max_constant_size_in_bytes = 10 * 1024 * 1024; + + // A generator for the name suffix of constant folded nodes. A + // default id generator that monotonically increases is used if nullptr is + // passed. + ConstantFoldNameGenerator generate_new_name = nullptr; }; // Perform constant folding optimization on "graph". diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 923a4d924936386ce0e06c6355c2a4d0af5cc4a4..6ac9319ad1e2c4953c2d82257dac6a3aeeffcd5c 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -121,6 +121,58 @@ TEST_F(ConstantFoldingTest, Basic) { {2, 2}); } +// Tests that different node creation ordering creates same graph after constant +// folding. +TEST_F(ConstantFoldingTest, DeterministicFolding) { + auto build_graph_and_constant_folding = [](Graph& g, bool swap) -> Status { + Scope s = Scope::NewRootScope(); + auto a = ops::Const(s, {1.0}, {}); + auto b = ops::Const(s, {2.0}, {}); + + if (swap) { + auto add1 = ops::Add(s.WithOpName("add1"), a, b); + auto add2 = ops::Add(s.WithOpName("add2"), a, b); + auto s1 = + ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver"); + auto s2 = + ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver"); + } else { + // Swap the order of node creation. + auto add2 = ops::Add(s.WithOpName("add2"), a, b); + auto add1 = ops::Add(s.WithOpName("add1"), a, b); + auto s1 = + ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver"); + auto s2 = + ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver"); + } + + TF_CHECK_OK(s.ToGraph(&g)); + bool was_mutated; + int64 unique_id = 0; + auto generate_new_name = [&unique_id](Graph* graph, string old_name) { + return strings::StrCat(graph->NewName(old_name), "__cf__", unique_id++); + }; + ConstantFoldingOptions opt{}; + opt.generate_new_name = generate_new_name; + TF_CHECK_OK( + ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated)); + return Status::OK(); + }; + + Graph g1(OpRegistry::Global()); + TF_ASSERT_OK(build_graph_and_constant_folding(g1, false)); + Graph g2(OpRegistry::Global()); + TF_ASSERT_OK(build_graph_and_constant_folding(g2, true)); + EXPECT_EQ(g1.num_nodes(), g2.num_nodes()); + auto index = NodeNameIndex(g2); + + // All the nodes in g1 are expected to be present in g2. + for (int64 i = 0; i < g1.num_nodes(); ++i) { + Node* n1 = g1.FindNodeId(i); + EXPECT_GT(index.count(n1->name()), 0); + } +} + TEST_F(ConstantFoldingTest, ConsiderFunction) { Scope s = Scope::NewRootScope(); BuildSimpleGraph(&s); diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index 0507076c8c3734083ac0ef7ffea0edebf180ad1a..fd9c4222a7afd4914415c9c62e1ced118ea75d1f 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -88,7 +88,9 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) { // D3 is prioritized below D1. AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0"); EXPECT_EQ((std::vector{ - DeviceType("d2"), DeviceType("d1"), DeviceType("d3"), + DeviceType("d2"), + DeviceType("d1"), + DeviceType("d3"), }), types()); } diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 20c59ad42b35865f6fdad60e8bc8ac5ffebc4415..df6f4b88773fb1a72100d1c223276a06b857a908 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -61,7 +61,6 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/env_var.h" - namespace tensorflow { namespace { @@ -472,9 +471,9 @@ Status DirectSession::Run(const RunOptions& run_options, Executor::Args args; args.step_id = step_id_counter_.fetch_add(1); - TF_RETURN_IF_ERROR( - GetOrCreateExecutors(input_tensor_names, output_names, target_nodes, - &executors_and_keys, &run_state_args)); + TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, + target_nodes, &executors_and_keys, + &run_state_args)); const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); std::unique_ptr debugger_state; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 99b33e2ef0d532aca08dfb538857d347d22a7351..b75a4f76d94f704cf38a6c4657b6089a863c085f 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -436,10 +436,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) { } } -REGISTER_OP("Darth") - .Input("x: float") - .Output("y: float") - .Doc(R"doc( +REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc( Darth promises one return value. x: float @@ -972,39 +969,38 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, std::atomic num_done(0); // Runs session to compute :0 using inter_op thread pool . - auto add_session_run_call = [use_global_pools, &def, &options, &sessions, - &sessions_mu, - &num_done](thread::ThreadPool* tp, Node* node, - int inter_op_pool) { - auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu, - inter_op_pool, node, &num_done]() { - RunOptions run_options; - run_options.set_inter_op_thread_pool(inter_op_pool); - std::vector outputs; - - Session* session; - if (use_global_pools) { - std::unique_ptr s(NewSession(options)); - TF_ASSERT_OK(s->Create(def)); - session = s.get(); - - mutex_lock l(sessions_mu); - sessions.emplace_back(std::move(s)); - } else { - session = sessions[0].get(); - } + auto add_session_run_call = + [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done]( + thread::ThreadPool* tp, Node* node, int inter_op_pool) { + auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu, + inter_op_pool, node, &num_done]() { + RunOptions run_options; + run_options.set_inter_op_thread_pool(inter_op_pool); + std::vector outputs; + + Session* session; + if (use_global_pools) { + std::unique_ptr s(NewSession(options)); + TF_ASSERT_OK(s->Create(def)); + session = s.get(); + + mutex_lock l(sessions_mu); + sessions.emplace_back(std::move(s)); + } else { + session = sessions[0].get(); + } - Status s = session->Run(run_options, {} /* inputs */, - {node->name() + ":0"} /* output_names */, {}, - &outputs, nullptr /* run_metadata */); - TF_CHECK_OK(s); - ASSERT_EQ(1, outputs.size()); - auto flat = outputs[0].flat(); - EXPECT_FLOAT_EQ(1.2, flat(0)); - num_done.fetch_add(1); - }; - tp->Schedule(fn); - }; + Status s = session->Run(run_options, {} /* inputs */, + {node->name() + ":0"} /* output_names */, {}, + &outputs, nullptr /* run_metadata */); + TF_CHECK_OK(s); + ASSERT_EQ(1, outputs.size()); + auto flat = outputs[0].flat(); + EXPECT_FLOAT_EQ(1.2, flat(0)); + num_done.fetch_add(1); + }; + tp->Schedule(fn); + }; // For blocking states: // - Starts at 0, BlockingOp::Compute will move to 1. diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index df9cf0c91f1b7e5521061b6915fc1b7ed609e003..31fb128f937ae46eefb309fc9bab8167e54846a7 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -161,14 +161,14 @@ static void TestHWAccelerator(bool enableHWTrace) { x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); #ifdef TENSORFLOW_USE_SYCL x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // y = A * x Node* y = test::graph::Matmul(&graph, a, x, false, false); y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); #ifdef TENSORFLOW_USE_SYCL -y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); -#endif // TENSORFLOW_USE_SYCL + y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); +#endif // TENSORFLOW_USE_SYCL Node* y_neg = test::graph::Unary(&graph, "Neg", y); y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); @@ -181,7 +181,7 @@ y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); (*options.config.mutable_device_count())["GPU"] = 1; #ifdef TENSORFLOW_USE_SYCL (*options.config.mutable_device_count())["SYCL"] = 1; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL options.config.set_allow_soft_placement(true); options.config.mutable_graph_options()->set_build_cost_model(1); std::unique_ptr session(NewSession(options)); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 9d03caff1e1e89c4c667f94853352580545e70e5..e3416da988cf14f62e4dd44ebbdcb64f1fb0b1cb 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1609,7 +1609,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { auto done = [this, state]() { Device* device = impl_->params_.device; NodeExecStatsWrapper* stats = state->stats; // Shorthand - Entry* first_input = state->first_input; // Shorthand + Entry* first_input = state->first_input; // Shorthand nodestats::SetOpEnd(stats); EntryVector outputs; @@ -1776,6 +1776,19 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, entry->ref_mu = nullptr; inp->tensor = entry->val.get(); + // The dtype of entry->ref could have been changed by another operation + // that ran after the operation that "produced" it executed, so + // re-validate that the type of the dereferenced tensor matches the + // expected input type. + if (item.input_type(i) != inp->tensor->dtype()) { + return AttachDef( + errors::InvalidArgument( + i, "-th input expects type ", + DataTypeString(item.input_type(i)), + " but automatically dereferenced input tensor has type ", + DataTypeString(inp->tensor->dtype())), + item.kernel->def()); + } } } } diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index e9c4328f29e2c941afd8e14142beb0db224110d8..d349d2bb1251d8b31d8e432ad8c357d6f0a81389 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -97,12 +97,11 @@ static Node* AddNoOp(Graph* g) { static Node* AddIdentity(Graph* g, Endpoint input) { DCHECK_LT(0, input.dtype()); - DCHECK_LT(input.dtype(), DT_FLOAT_REF); NodeDef ndef; ndef.set_name(g->NewName(kNodeLabel)); ndef.set_op("Identity"); ndef.add_input(input.name()); - AddNodeAttr("T", input.dtype(), &ndef); + AddNodeAttr("T", BaseType(input.dtype()), &ndef); Status s; Node* ret = g->AddNode(ndef, &s); TF_CHECK_OK(s); @@ -183,6 +182,10 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { string DebugString(Handle h) override; + Status Clone(std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) override; + private: typedef FunctionLibraryRuntimeImpl ME; @@ -205,7 +208,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { // The instantiated and transformed function is encoded as a Graph // object, and an executor is created for the graph. struct Item : public core::RefCounted { - const Graph* graph = nullptr; // Owned by exec. + const Graph* graph = nullptr; // Owned by exec. const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned. FunctionBody* func_graph = nullptr; Executor* exec = nullptr; @@ -895,6 +898,21 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { } } +Status FunctionLibraryRuntimeImpl::Clone( + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) { + TF_RETURN_IF_ERROR( + parent_->Clone(env_, graph_def_version_, optimizer_.options(), + custom_kernel_creator_, out_lib_def, out_pflr)); + *out_flr = (*out_pflr)->GetFLR(device_->name()); + if (out_flr != nullptr) { + return Status::OK(); + } else { + return errors::Internal("Cloning FunctionLibraryRuntime failed."); + } +} + namespace { struct CustomCreatorSingleton { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index cad3b3801e74a00a9f6fb6b236842f5caeaf72bc..8b051462990fc3abc5b864f644274bb8b2211191 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -787,7 +787,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const( - s.WithOpName("x4/x2/scale/_15__cf__9") + s.WithOpName("x4/x2/scale/_12__cf__6") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -993,13 +993,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_5__cf__10") + s.WithOpName("scale/_6__cf__11") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_6__cf__11") + s.WithOpName("Func/_1/sy/_5__cf__10") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index 9e4b617d2bd5b070f5b8bdeedabb15b94d212743..67caeb3495c6b0600f12c9b20ef73ee90f8b3e0d 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -154,8 +154,9 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { a.DeallocateRaw(t3); a.DeallocateRaw(t4); } - CheckStats(&a, 4097, 0, 1024 * sizeof(float) + 1048576 * sizeof(int64) + - 2048 * sizeof(double) + 10485760 * sizeof(float), + CheckStats(&a, 4097, 0, + 1024 * sizeof(float) + 1048576 * sizeof(int64) + + 2048 * sizeof(double) + 10485760 * sizeof(float), 10485760 * sizeof(float)); // At the end, we should have coalesced all memory into one region diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 933d700f6042bf51f11f773d731cece6ef5af436..80a5bdbfff4ddfc40eb6beba619cd97c308b04c9 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -762,9 +762,11 @@ int64 MinSystemMemory(int64 available_memory) { // is necessary. min_system_memory *= 2; #endif + #if defined(ANDROID_TEGRA) - // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM - min_system_memory = 1<<30; + // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM + // and Video RAM + min_system_memory = 1 << 30; #endif return min_system_memory; } diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc index 7763a4f2e6f50292e78b4d16d8d4a3ee84d4163b..2500425359c424fa479af6dd34d6a0312c404577 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc @@ -108,7 +108,8 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) { ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0, "/device:GPU:0"); Output n = ops::MatMul(root, {}, {}); - ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, "/cpu:0"); + ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, + "/cpu:0"); Graph g(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(&g)); diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc index 2f13cf8bd76e955e2ead1b2cf575b27a14e40b26..b195de7cbace095cfb29fa2adf1ee5f44853cab5 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.cc +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -88,8 +88,8 @@ ProcessState::~ProcessState() { } string ProcessState::MemDesc::DebugString() { - return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ", - gpu_registered, ", nic: ", nic_registered); + return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, + ", dma: ", gpu_registered, ", nic: ", nic_registered); } ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 3b309e915cdd2c6d5eead9ed0312f3873bcf7335..33a5d60eb7ec4de829d3c0784f909ef42cf994d1 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -340,8 +340,11 @@ Status GraphExecutionState::OptimizeGraph( std::unordered_map device_map; Device* cpu_device = nullptr; for (const auto& device : device_set_->devices()) { - device_map[device->name()] = - grappler::GetDeviceInfo(device->parsed_name()); + DeviceProperties props = grappler::GetDeviceInfo(device->parsed_name()); + if (props.type() == "UNKNOWN") { + continue; + } + device_map[device->name()] = props; if (device->parsed_name().id == 0 && StringPiece(device->parsed_name().type) == "CPU" && device->GetAllocator(AllocatorAttributes()) != nullptr) { diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h index db2686ce2c45aa4c9997a624bb12720d63710b65..2312e1a89fd1fd5734fab4316c25ca2e39f16ae5 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.h +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -139,9 +139,7 @@ class GraphExecutionState { // The graph returned by BuildGraph may contain only the pruned // graph, whereas some clients may want access to the full graph. - const Graph* full_graph() { - return graph_; - } + const Graph* full_graph() { return graph_; } // Returns the node with the given name, or null if it does not exist. const Node* get_node_by_name(const string& name) const { diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 8477cea126f1808d9472bd4f4127fd43e172848e..80246281cde373863e4da1bb8d86bee39bfb9dfd 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -52,6 +52,8 @@ class GraphOptimizer { shape_map, const std::function& cse_consider_fn = nullptr); + const OptimizerOptions& options() { return opts_; } + private: OptimizerOptions opts_; diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc index 76b926ba40053288360f0e4e6fe2a37bd44ff0b4..090a16ebeb10007261666aeb6491a1785dd2e5c4 100644 --- a/tensorflow/core/common_runtime/memory_types.cc +++ b/tensorflow/core/common_runtime/memory_types.cc @@ -47,7 +47,7 @@ struct EndpointEq { static Status ProcessMemoryTypes( const DeviceType& device_type, const Graph* g, const std::function& fn) { - if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL ) { + if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) { // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always // compatible. return Status::OK(); diff --git a/tensorflow/core/common_runtime/memory_types_test.cc b/tensorflow/core/common_runtime/memory_types_test.cc index 2a834ddca4236c626c6252f63c97118e8e1f0bd0..a093585571994e8b161b46a7fc397cdc3cd4254c 100644 --- a/tensorflow/core/common_runtime/memory_types_test.cc +++ b/tensorflow/core/common_runtime/memory_types_test.cc @@ -36,7 +36,7 @@ TEST(MemoryTypeChecker, Int32OK) { #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL delete g; } @@ -64,7 +64,7 @@ TEST(MemoryTypeChecker, Int32NotOk) { // But we can insert _HostSend/_HostRecv to ensure the invariant. TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g)); TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL delete g; } @@ -91,7 +91,7 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) { TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type)); // int Switch's output on GPU has HOST_MEMORY constraint. EXPECT_EQ(memory_type, HOST_MEMORY); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL delete g; } diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h index c5b76592e1b4b86863009ef42b7bb7106377d054..75dce7c7feb2269fc994cbb8c5efd4b3799e75dd 100644 --- a/tensorflow/core/common_runtime/placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -88,9 +88,9 @@ class Placer { void AssignAndLog(int assigned_device, Node* node) const; void LogDeviceAssignment(const Node* node) const; - Graph* const graph_; // Not owned. - const DeviceSet* const devices_; // Not owned. - const SessionOptions* options_; // Not owned. + Graph* const graph_; // Not owned. + const DeviceSet* const devices_; // Not owned. + const SessionOptions* options_; // Not owned. const bool log_device_placement_; TF_DISALLOW_COPY_AND_ASSIGN(Placer); diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 5d87b1e279ab0390a642df8f285fd451803ba29a..02c9cd5313ee24c83243d27baf688341451996c5 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -619,9 +619,9 @@ TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) { Node* input = ops::SourceOp( "TestDevice", b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); - Node* var = ops::SourceOp("TestVariable", - b.opts().WithName("var_0").WithDevice( - "/job:a/task:0/device:fakegpu:0")); + Node* var = + ops::SourceOp("TestVariable", b.opts().WithName("var_0").WithDevice( + "/job:a/task:0/device:fakegpu:0")); // This op is specified on CPU, but in practice will be ignored, // because the reference edges forces it on GPU. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 12947e284a36fef171caf6af0c46d59ca89efb61..f9d9633beea1c59dc79880b2120332f3ee7588bd 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -70,23 +70,6 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( } } -ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options) - : ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, lib_def, - optimizer_options, - nullptr /* cluster_flr */) {} - -ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator) - : ProcessFunctionLibraryRuntime( - device_mgr, env, graph_def_version, lib_def, optimizer_options, - std::move(custom_kernel_creator), nullptr /* cluster_flr */) {} - /* static */ Status ProcessFunctionLibraryRuntime::SendTensors( const string& source_device, const string& target_device, @@ -158,7 +141,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( } FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( - const string& device_name) { + const string& device_name) const { Device* device = nullptr; if (device_name != kDefaultFLRDevice) { if (!device_mgr_->LookupDevice(device_name, &device).ok()) { @@ -350,4 +333,16 @@ void ProcessFunctionLibraryRuntime::Run( done(errors::Internal("Could not find device")); } +Status ProcessFunctionLibraryRuntime::Clone( + Env* env, int graph_def_version, const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator, + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr) { + out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_)); + out_pflr->reset(new ProcessFunctionLibraryRuntime( + device_mgr_, env, graph_def_version, out_lib_def->get(), + optimizer_options, std::move(custom_kernel_creator), parent_)); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index a1adc4b6b35950339b727774c45014ef71839554..0473e16d242814930a9de17c88d4851d0d73edbe 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -29,12 +29,13 @@ class ProcessFunctionLibraryRuntime { // Creates FunctionLibraryRuntime objects for each device in the provided // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent // (if provided) outlive this object. - ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, - int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - DistributedFunctionLibraryRuntime* parent); + ProcessFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + DistributedFunctionLibraryRuntime* parent = nullptr); + // With `custom_kernel_creator`. ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, @@ -42,17 +43,6 @@ class ProcessFunctionLibraryRuntime { CustomKernelCreator custom_kernel_creator, DistributedFunctionLibraryRuntime* parent); - ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, - int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options); - - ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, - int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator); - // Sends `tensors_to_send` from `source_device` to `target_device` using // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the // Rendezvous. `device_context` should be the DeviceContext of the device @@ -85,7 +75,7 @@ class ProcessFunctionLibraryRuntime { static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. - FunctionLibraryRuntime* GetFLR(const string& device_name); + FunctionLibraryRuntime* GetFLR(const string& device_name) const; // Returns the device incarnation for the given device_name. Status GetDeviceIncarnation(const string& device_name, int64* incarnation); @@ -145,6 +135,12 @@ class ProcessFunctionLibraryRuntime { // Removes handle from the state owned by this object. Status RemoveHandle(FunctionLibraryRuntime::Handle handle); + Status Clone(Env* env, int graph_def_version, + const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator, + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr); + friend class FunctionLibraryRuntimeImpl; mutable mutex mu_; diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc index 0234d4c37250d8ed3c645759dd17f94093e57df0..4dbe113e44ee0b7a6eba44ace3c1ff8daa17059f 100644 --- a/tensorflow/core/common_runtime/session_factory.cc +++ b/tensorflow/core/common_runtime/session_factory.cc @@ -60,8 +60,8 @@ const string RegisteredFactoriesErrorMessageLocked() { str_util::Join(factory_types, ", "), "}."); } string SessionOptionsToString(const SessionOptions& options) { - return strings::StrCat("target: \"", options.target, "\" config: ", - ProtoShortDebugString(options.config)); + return strings::StrCat("target: \"", options.target, + "\" config: ", ProtoShortDebugString(options.config)); } } // namespace diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index d7e01144c9ef3aa09ddd212947eafe48ccff555b..cb900db10af98496cfdfafa5a38296bfdc4e996b 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -226,22 +226,23 @@ void StepStatsCollector::BuildCostModel( if (node) { for (int i = 0; i < stats.output_size(); ++i) { const auto& output = stats.output(i); - cm->RecordMaxMemorySize(node, i, Bytes(output.tensor_description() - .allocation_description() - .allocated_bytes()), + cm->RecordMaxMemorySize(node, i, + Bytes(output.tensor_description() + .allocation_description() + .allocated_bytes()), stats.output(i).tensor_description().shape(), node->output_types()[i]); - cm->RecordAllocationId(node, i, output.tensor_description() - .allocation_description() - .allocation_id()); + cm->RecordAllocationId(node, i, + output.tensor_description() + .allocation_description() + .allocation_id()); } cm->RecordMemoryStats(node, stats.memory_stats()); // Use hardware stats to record the execution time if they're available, // otherwise use the regular (less accurate) stats string node_name = dev_stats.regular_stats->node_stats(i).node_name(); - if (dev_stats.hardware_stats && - name_to_hw_node_stats.find(node_name) != - name_to_hw_node_stats.end()) { + if (dev_stats.hardware_stats && name_to_hw_node_stats.find(node_name) != + name_to_hw_node_stats.end()) { const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name]; cm->RecordMaxExecutionTime( node, Microseconds(hw_stats.op_end_rel_micros())); diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc index 9094824ee734a9398db5aca2a507af4acd07c26b..02bd8b8f3bc692728ce73176f6268d95f860dc9b 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc @@ -80,7 +80,7 @@ void SYCLAllocator::ClearStats() override { size_t SYCLAllocator::RequestedSize(void* ptr) { mutex_lock lock(mu_); - if(!sycl_device_) { + if (!sycl_device_) { return 0; } const auto& buffer = sycl_device_->get_sycl_buffer(ptr); diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h index cca9f92c62e2a4f4d57c8a6111b53dccee505f93..550f1933322420fc97da2bb588c719c73ea5ae4d 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h @@ -20,10 +20,10 @@ limitations under the License. #ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ #define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -56,14 +56,13 @@ class SYCLAllocator : public Allocator { // Clear the SYCL device used by the Allocator void ClearSYCLDevice() { mutex_lock lock(mu_); - if(sycl_device_) { + if (sycl_device_) { delete sycl_device_; sycl_device_ = nullptr; } } private: - mutable mutex mu_; Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_); // owned AllocatorStats stats_ GUARDED_BY(mu_); diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.h b/tensorflow/core/common_runtime/sycl/sycl_device.h index cc272d156ef67a4f4f93f35603ffe301d154932a..7c09e0b8f194c7dc8a594aa487ec62e00d5b5e39 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.h +++ b/tensorflow/core/common_runtime/sycl/sycl_device.h @@ -187,9 +187,9 @@ class GSYCLInterface { type = "Unknown"; } - return strings::StrCat("id: ", device_id, ", type: ", type, ", name: ", - name.c_str(), ", vendor: ", vendor.c_str(), - ", profile: ", profile.c_str()); + return strings::StrCat( + "id: ", device_id, ", type: ", type, ", name: ", name.c_str(), + ", vendor: ", vendor.c_str(), ", profile: ", profile.c_str()); } }; diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc index 19c14770dcad7a3ca045ccb4ff68189c943d8cff..14f7727659d91db2373a1ac8ad0e46258cc32fbe 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc @@ -26,7 +26,6 @@ class SYCLDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions &options, const string &name_prefix, std::vector *devices) override { - auto syclInterface = GSYCLInterface::instance(); size_t n = 1; @@ -37,13 +36,11 @@ class SYCLDeviceFactory : public DeviceFactory { for (int i = 0; i < n; i++) { string name = strings::StrCat(name_prefix, "/device:SYCL:", i); - devices->push_back( - new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality() - , syclInterface->GetShortDeviceDescription(i) - , syclInterface->GetSYCLAllocator(i) - , syclInterface->GetCPUAllocator(i) - , syclInterface->GetSYCLContext(i)) - ); + devices->push_back(new SYCLDevice( + options, name, Bytes(256 << 20), DeviceLocality(), + syclInterface->GetShortDeviceDescription(i), + syclInterface->GetSYCLAllocator(i), syclInterface->GetCPUAllocator(i), + syclInterface->GetSYCLContext(i))); } return Status::OK(); @@ -51,6 +48,6 @@ class SYCLDeviceFactory : public DeviceFactory { }; REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200); -} +} // namespace tensorflow #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/common_runtime/sycl/sycl_util.h b/tensorflow/core/common_runtime/sycl/sycl_util.h index 83016b706a57033bfdaec932f763bc118434db90..3124ed23c92eb542e90e6c077fc703fb84b38a18 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_util.h +++ b/tensorflow/core/common_runtime/sycl/sycl_util.h @@ -20,8 +20,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ -#include "tensorflow/core/common_runtime/device.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/common_runtime/device.h" // For DMA helper #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/debug/debug_gateway.cc b/tensorflow/core/debug/debug_gateway.cc index 616ced3d0f3d9cfed683120e792b40eb9010fe06..2e1aabd1cc8066df6a5f7e6dd0aa27c6a16ef614 100644 --- a/tensorflow/core/debug/debug_gateway.cc +++ b/tensorflow/core/debug/debug_gateway.cc @@ -24,31 +24,31 @@ limitations under the License. namespace tensorflow { DebugGateway::DebugGateway(DirectSession* session) : session_(session) { - session_->node_outputs_callback_ = [this]( - const string& node_name, const int output_slot, const Tensor* tensor, - const bool is_ref, OpKernelContext* ctx) { - if (comp_cb_ != nullptr && output_slot <= 0) { - // The node completion callback is invoked once for a node regardless - // of whether the node has zero, one or more outputs. - // The output_slot can be negative (-1, or kControlSlot) if - // node_outputs_callback_ is invoked for a node with no output. If that - // is the case, notify the callback that the node in question has no - // output. - comp_cb_(node_name, output_slot == 0); - } - - // Copy tensor values (e.g., from GPU to host) only if the - // value callback is not nullptr. - if (val_cb_ != nullptr && output_slot >= 0) { - CopyTensor( - node_name, output_slot, tensor, ctx, - [this, node_name, output_slot, is_ref](const Tensor* copied_tensor) { - val_cb_(node_name, output_slot, *copied_tensor, is_ref); - }); - } - - return Status::OK(); - }; + session_->node_outputs_callback_ = + [this](const string& node_name, const int output_slot, + const Tensor* tensor, const bool is_ref, OpKernelContext* ctx) { + if (comp_cb_ != nullptr && output_slot <= 0) { + // The node completion callback is invoked once for a node regardless + // of whether the node has zero, one or more outputs. + // The output_slot can be negative (-1, or kControlSlot) if + // node_outputs_callback_ is invoked for a node with no output. If + // that is the case, notify the callback that the node in question has + // no output. + comp_cb_(node_name, output_slot == 0); + } + + // Copy tensor values (e.g., from GPU to host) only if the + // value callback is not nullptr. + if (val_cb_ != nullptr && output_slot >= 0) { + CopyTensor(node_name, output_slot, tensor, ctx, + [this, node_name, output_slot, + is_ref](const Tensor* copied_tensor) { + val_cb_(node_name, output_slot, *copied_tensor, is_ref); + }); + } + + return Status::OK(); + }; } DebugGateway::~DebugGateway() { @@ -86,7 +86,8 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot, // Determine if the tensor is on device (GPU) or host (CPU). // The second part of the check is necessary because even an OpKernel on // may have output tensors allocated on CPU. - if ((device->name().find("GPU:") != string::npos || device->name().find("SYCL:") != string::npos) && + if ((device->name().find("GPU:") != string::npos || + device->name().find("SYCL:") != string::npos) && !ctx->output_alloc_attr(output_slot).on_host()) { // GPU tensors: Copy it to host (CPU). DeviceContext* device_ctxt = ctx->op_device_context(); diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc index 57583349069a0b4deb137cb09564cdbb3909a4b0..b1bbd3f6980b16c13a1e5c9cd3a0f6c4bb8c1217 100644 --- a/tensorflow/core/debug/debug_gateway_test.cc +++ b/tensorflow/core/debug/debug_gateway_test.cc @@ -390,9 +390,9 @@ TEST_F(SessionDebugMinusAXTest, debug_gateway.SetNodeValueCallback( [this, &mu, &val_callback_count, &a_debug_identity_node_name, &x_debug_identity_node_name, &y_debug_identity_node_name, - &debug_identity_tensor_vals, &callbacks_done, &kConcurrentRuns]( - const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { + &debug_identity_tensor_vals, &callbacks_done, + &kConcurrentRuns](const string& node_name, const int output_slot, + const Tensor& tensor_value, const bool is_ref) { mutex_lock l(mu); if (node_name == a_debug_identity_node_name && output_slot == 0) { @@ -560,21 +560,21 @@ TEST_F(SessionDebugOutputSlotWithoutOutgoingEdgeTest, Notification callbacks_done; std::vector debug_identity_tensor_vals; - debug_gateway.SetNodeValueCallback([this, &mu, &callbacks_done, - &debug_identity_node_name, - &debug_identity_tensor_vals]( - const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); + debug_gateway.SetNodeValueCallback( + [this, &mu, &callbacks_done, &debug_identity_node_name, + &debug_identity_tensor_vals]( + const string& node_name, const int output_slot, + const Tensor& tensor_value, const bool is_ref) { + mutex_lock l(mu); - if (node_name == debug_identity_node_name && output_slot == 0) { - debug_identity_tensor_vals.push_back(tensor_value); + if (node_name == debug_identity_node_name && output_slot == 0) { + debug_identity_tensor_vals.push_back(tensor_value); - if (!callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - } - }); + if (!callbacks_done.HasBeenNotified()) { + callbacks_done.Notify(); + } + } + }); // Add DebugIdentity watch on c:0, which does not have an outgoing edge. RunOptions run_opts; diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc index a312f789d8444360a0892faa4b3a0f9a0bdf7a32..f70931e926507c72287588da278a3b8d6bb19122 100644 --- a/tensorflow/core/debug/debug_grpc_testlib.cc +++ b/tensorflow/core/debug/debug_grpc_testlib.cc @@ -30,7 +30,7 @@ namespace test { ::grpc::Status TestEventListenerImpl::SendEvents( ::grpc::ServerContext* context, - ::grpc::ServerReaderWriter< ::tensorflow::EventReply, ::tensorflow::Event>* + ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>* stream) { Event event; diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index 2f83c2415b831cc1a2b90d4e6a2046218e6fe5f6..0807a85b8b39cf8bf479227bd6b6bd581e2ba9b0 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -57,7 +57,8 @@ class DebugIOUtilsTest : public ::testing::Test { TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) { DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2", "hidden_1/MatMul", 0, "DebugIdentity"); - EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", debug_node_key.device_name); + EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", + debug_node_key.device_name); EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name); EXPECT_EQ(0, debug_node_key.output_slot); EXPECT_EQ("DebugIdentity", debug_node_key.debug_op); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index d0ca2a625778ff73c6d40492cc5d02ec81ef3cc6..cc35264b8fe0b6decc325dab793c6a5fe6ad097f 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -140,7 +140,7 @@ class GraphMgr { GraphMgr* graph_mgr; }; - const WorkerEnv* worker_env_; // Not owned. + const WorkerEnv* worker_env_; // Not owned. DeviceMgr* device_mgr_; CostModelManager cost_model_manager_; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index d1dc622ce79df1a98c3712e447a66bad3baecba1..1a488303ac73b8628b9d3fe4050ad9144724348e 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -528,8 +528,8 @@ void Master::ListDevices(const ListDevicesRequest* req, auto session = FindMasterSession(req->session_handle()); if (session == nullptr) { done(errors::InvalidArgument( - "Session ", req->session_handle(), - " is not found. Possibly, this master has restarted.")); + "Session ", req->session_handle(), + " is not found. Possibly, this master has restarted.")); return; } core::ScopedUnref ref(session); diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc index 121c58762f10a87fea059ce43b190f70e49e1f64..f2c1f3489c388d6a5fff729b1c8f98136532105c 100644 --- a/tensorflow/core/distributed_runtime/master_test.cc +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -61,7 +61,7 @@ class MasterTest : public ::testing::Test { // rpc calls. Status CreateSession(const GraphDef& def, string* handle, - int64* initial_version) { + int64* initial_version) { ::grpc::ClientContext ctx; CreateSessionRequest req; *(req.mutable_graph_def()) = def; @@ -77,7 +77,7 @@ class MasterTest : public ::testing::Test { } Status ExtendSession(const string& handle, const GraphDef& def, - int64 current_version, int64* new_version) { + int64 current_version, int64* new_version) { ::grpc::ClientContext ctx; ExtendSessionRequest req; req.set_session_handle(handle); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index ac279937730466514451d7e81257d2110e128eff..b4d18d8607eaddd75f4e395e71fbd75554645a61 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -185,23 +185,22 @@ class GrpcMasterService : public AsyncServiceInterface { MutableRunStepResponseWrapper* wrapped_response = new NonOwnedProtoRunStepResponse(&call->response); call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - master_impl_->RunStep(call_opts, wrapped_request, wrapped_response, - [call, call_opts, wrapped_request, wrapped_response, - trace](const Status& status) { - call->ClearCancelCallback(); - delete call_opts; - delete wrapped_request; - delete trace; - if (call->request.store_errors_in_response_body() && - !status.ok()) { - call->response.set_status_code(status.code()); - call->response.set_status_error_message( - status.error_message()); - call->SendResponse(ToGrpcStatus(Status::OK())); - } else { - call->SendResponse(ToGrpcStatus(status)); - } - }); + master_impl_->RunStep( + call_opts, wrapped_request, wrapped_response, + [call, call_opts, wrapped_request, wrapped_response, + trace](const Status& status) { + call->ClearCancelCallback(); + delete call_opts; + delete wrapped_request; + delete trace; + if (call->request.store_errors_in_response_body() && !status.ok()) { + call->response.set_status_code(status.code()); + call->response.set_status_error_message(status.error_message()); + call->SendResponse(ToGrpcStatus(Status::OK())); + } else { + call->SendResponse(ToGrpcStatus(status)); + } + }); ENQUEUE_REQUEST(RunStep, true); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h index 4e203e260a1a370cc2bc7e40c3ce9e84da4d3ad4..6ae94b74417c3fb6c4da1589bb9f532cb6d79930 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -89,9 +89,9 @@ class MasterService final { ::grpc::Status ExtendSession(::grpc::ClientContext* context, const ExtendSessionRequest& request, ExtendSessionResponse* response) override; - ::grpc::Status PartialRunSetup( - ::grpc::ClientContext* context, const PartialRunSetupRequest& request, - PartialRunSetupResponse* response) override; + ::grpc::Status PartialRunSetup(::grpc::ClientContext* context, + const PartialRunSetupRequest& request, + PartialRunSetupResponse* response) override; ::grpc::Status RunStep(::grpc::ClientContext* context, const RunStepRequest& request, RunStepResponse* response) override; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 70418f63686843414dca6c5ae4907ee263dc2904..1088e9be66ceb7fbddfaed0691423745f362343f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -69,8 +69,7 @@ class GrpcRemoteMaster : public MasterInterface { ::grpc::ClientContext ctx; auto trace = TraceRpc("RunStep/Client", &ctx); return Call(&ctx, call_options, &request->ToProto(), - get_proto_from_wrapper(response), - &MasterServiceStub::RunStep); + get_proto_from_wrapper(response), &MasterServiceStub::RunStep); } Status CloseSession(CallOptions* call_options, @@ -114,8 +113,9 @@ class GrpcRemoteMaster : public MasterInterface { template Status Call(::grpc::ClientContext* ctx, CallOptions* call_options, const Request* request, Response* response, - ::grpc::Status (MasterServiceStub::*pfunc)( - ::grpc::ClientContext*, const Request&, Response*)) { + ::grpc::Status (MasterServiceStub::*pfunc)(::grpc::ClientContext*, + const Request&, + Response*)) { ctx->set_fail_fast(false); SetDeadline(ctx, call_options->GetTimeout()); return FromGrpcStatus((stub_.get()->*pfunc)(ctx, *request, response)); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h index dd114d39c62f6b69a3fb9ea4401459f963137a1f..730124c25e9a3e8d102a9dd39e4c4a17f2ce39d1 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h @@ -66,7 +66,7 @@ class GrpcBufferWriter final } // It's dangerous to keep an inlined grpc_slice as the backup slice, since // on a following Next() call, a reference will be returned to this slice - // via GRPC_SLICE_START_PTR, which will not be an adddress held by + // via GRPC_SLICE_START_PTR, which will not be an address held by // slice_buffer_. have_backup_ = backup_slice_.refcount != NULL; byte_count_ -= count; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc index 373eecffcab1dded60de7ffea96ba58208bb692c..5597ee7a76a55f125dd0db82eceb58f5e922ab13 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc @@ -21,11 +21,8 @@ namespace tensorflow { namespace test { // ErrorOp::Compute returns an error. -REGISTER_OP("Error") - .Input("in: T") - .Output("out: T") - .Attr("T: type") - .Attr("message: string"); +REGISTER_OP("Error").Input("in: T").Output("out: T").Attr("T: type").Attr( + "message: string"); class ErrorOp : public OpKernel { public: explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -66,11 +63,8 @@ REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU), // DelayOp::AsyncCompute sleeps for "micros"-econd and then returns // its input. -REGISTER_OP("Delay") - .Input("in: T") - .Output("out: T") - .Attr("T: type") - .Attr("micros: int"); +REGISTER_OP("Delay").Input("in: T").Output("out: T").Attr("T: type").Attr( + "micros: int"); class DelayOp : public AsyncOpKernel { public: explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { diff --git a/tensorflow/core/distributed_runtime/rpcbench_test.cc b/tensorflow/core/distributed_runtime/rpcbench_test.cc index b2668fae25a8a6bc60b37ddfaa83b8b523c3a6f5..d3af7417e61105c788b8029c84c222e49a0d2830 100644 --- a/tensorflow/core/distributed_runtime/rpcbench_test.cc +++ b/tensorflow/core/distributed_runtime/rpcbench_test.cc @@ -184,8 +184,8 @@ static void BM_Helper(int iters, int width, int num_stages, int tensor_size, testing::SetLabel( strings::StrCat(def.node_size(), " nodes; ", - use_multiple_devices ? "Multi device" : "Single device", - "; tensor bytes/send: ", tensor_size * sizeof(float))); + use_multiple_devices ? "Multi device" : "Single device", + "; tensor bytes/send: ", tensor_size * sizeof(float))); std::vector outputs; diff --git a/tensorflow/core/distributed_runtime/scheduler.cc b/tensorflow/core/distributed_runtime/scheduler.cc index 4766f4c33b654481f7d99ab82939e33e77564771..9dae5b3b926fab14c2b36955436d3956baa29fdd 100644 --- a/tensorflow/core/distributed_runtime/scheduler.cc +++ b/tensorflow/core/distributed_runtime/scheduler.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/util/util.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/scheduler.h b/tensorflow/core/distributed_runtime/scheduler.h index eabcaccdd1e6c1a732f8871bc9da6265bd9a8dd8..ef87b9834dba50cf628a8c29c70b0266661d6227 100644 --- a/tensorflow/core/distributed_runtime/scheduler.h +++ b/tensorflow/core/distributed_runtime/scheduler.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ -#include #include +#include #include #include #include -#include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/graph/costmodel.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc index fe2d1a12934dde814344b70f52fbc972f74347e0..34a4013547b5feef12b49198bff4e733f1b9e932 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding.cc @@ -81,7 +81,7 @@ void TensorResponse::InitPartial(const RecvTensorResponse& response) { Status TensorResponse::ParseFrom(Source* source) { if (!on_host_) { protobuf::io::CodedInputStream input(source->contents()); - input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited + input.SetTotalBytesLimit(INT_MAX); // Unlimited // Pre-parse into local storage, then delegate to device. if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) { @@ -217,7 +217,7 @@ bool TensorResponse::ParseTensorSubmessage( bool TensorResponse::ParseFast(Source* source) { protobuf::io::CodedInputStream input(source->contents()); - input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited + input.SetTotalBytesLimit(INT_MAX); // Unlimited while (true) { auto p = input.ReadTagWithCutoff(127); int tag = GetTagFieldNumber(p.first); diff --git a/tensorflow/core/distributed_runtime/worker_cache_logger.cc b/tensorflow/core/distributed_runtime/worker_cache_logger.cc index 702af78c88014d54fe2f72a8266e5e7e43b3cfb9..95ca3c3b4d11fac0d103eb52f19d5b0b2f4ad3ea 100644 --- a/tensorflow/core/distributed_runtime/worker_cache_logger.cc +++ b/tensorflow/core/distributed_runtime/worker_cache_logger.cc @@ -97,9 +97,8 @@ void WorkerCacheLogger::RecordDataTransfer(int64 step_id, int64 start_usecs, const string& tensor_name, const string& src_device, const string& dst_device, - int64 bytes, - const string& details, - const string& transfer_method_name){ + int64 bytes, const string& details, + const string& transfer_method_name) { NodeExecStats* ns = new NodeExecStats; ns->set_node_name(transfer_method_name); if (details.empty()) { diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 2c2c7e7c585c9364e1d08280d5fe76f1bf1eff23..96566c285a2c2407497ab1126055d6d873925b6d 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,64 +12,603 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ -#ifndef TENSORFLOW_FRAMEWORK_DATASET_H_ -#define TENSORFLOW_FRAMEWORK_DATASET_H_ +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/tracing.h" + +// Polymorphic datasets should support all primitive TensorFlow +// types. Use this macro to expand `m(T)` once for each primitive type +// `T`, e.g. to build a `switch` statement. +#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) namespace tensorflow { -namespace dataset { -// Registry for stateful ops that need to be used in dataset functions. -// See below macro for usage details. -class WhitelistedStatefulOpRegistry { + +// Interface for reading values from a key-value store. +// Used for restoring iterator state. +class IteratorStateReader { + public: + virtual Status ReadScalar(StringPiece key, int64* val) = 0; + virtual Status ReadScalar(StringPiece key, string* val) = 0; + virtual Status ReadTensor(StringPiece key, Tensor* val) = 0; + virtual bool Contains(StringPiece key) = 0; + + virtual ~IteratorStateReader() {} +}; + +// Interface for writing values to a key-value store. +// Used for saving iterator state. +class IteratorStateWriter { + public: + virtual Status WriteScalar(StringPiece key, const int64 val) = 0; + virtual Status WriteScalar(StringPiece key, const string& val) = 0; + virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; + + virtual ~IteratorStateWriter() {} +}; + +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class GraphDatasetBase; +class Node; + +// Wrapper around GraphDefBuilder. Used to serialize Dataset graph. +class GraphDefBuilderWrapper { public: - Status Add(StringPiece op_name) { - op_names_.insert(op_name); + explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {} + + // Adds a Const node with scalar value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + template + Status AddScalar(const T& val, Node** output) { + Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); + val_t.scalar()() = val; + AddTensorInternal(val_t, output); + if (*output == nullptr) { + return errors::Internal("AddScalar: Failed to build Const op."); + } return Status::OK(); } - bool Contains(StringPiece op_name) { - return op_names_.find(op_name) != op_names_.end(); + // Adds a Const node with vector value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? + template + Status AddVector(const std::vector& val, Node** output) { + Tensor val_t = Tensor(DataTypeToEnum::v(), + TensorShape({static_cast(val.size())})); + for (int i = 0; i < val.size(); i++) { + val_t.flat()(i) = val[i]; + } + AddTensorInternal(val_t, output); + if (*output == nullptr) { + return errors::Internal("AddVector: Failed to build Const op."); + } + return Status::OK(); } - static WhitelistedStatefulOpRegistry* Global() { - static WhitelistedStatefulOpRegistry* reg = - new WhitelistedStatefulOpRegistry; - return reg; + // Adds a Const node with Tensor value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + Status AddTensor(const Tensor& val, Node** output) { + AddTensorInternal(val, output); + if (*output == nullptr) { + return errors::Internal("AddTensor: Failed to build Const op."); + } + return Status::OK(); + } + + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector& inputs, Node** output) { + return AddDataset(dataset, inputs, {}, output); + } + + // Adds a node corresponding to the `DatasetType` to the Graph. + // Return value of `DatasetType::op_name()` is used as the op type for the + // node. + // Values for the output_types and output_shapes node attributes are also + // written if those attributes are defined in the OpDef. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector& inputs, + const std::vector>& attrs, + Node** output) { + std::vector> enumerated_inputs(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + enumerated_inputs[i] = std::make_pair(i, inputs[i]); + } + return AddDataset(dataset, enumerated_inputs, {}, attrs, output); + } + + Status AddDataset( + const GraphDatasetBase* dataset, + const std::vector>& inputs, + const std::vector>>& list_inputs, + const std::vector>& attrs, + Node** output); + + // Adds a user-defined function with name `function_name` to the graph and + // recursively adds all functions it references. If a function with a matching + // name has already been added, returns with OK status. If a user-defined with + // name `function_name` is not found in the FunctionLibraryDefinition, returns + // an InvalidArgumentError. If the function with name `function_name` or any + // of its dependent functions are stateful, returns an InvalidArgument error. + Status AddFunction(OpKernelContext* ctx, const string& function_name); + + template + void BuildAttrValue(const T& value, AttrValue* attr) { + SetAttrValue(value, attr); } private: - WhitelistedStatefulOpRegistry() {} - WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); - WhitelistedStatefulOpRegistry operator=( - WhitelistedStatefulOpRegistry const& copy); - std::set op_names_; + void AddTensorInternal(const Tensor& val, Node** output); + + Status EnsureFunctionIsStateless(OpKernelContext* ctx, + const string& function_name) const { + const FunctionLibraryDefinition* lib_def = + ctx->function_library()->GetFunctionLibraryDefinition(); + const FunctionDef* function_def = lib_def->Find(function_name); + if (!function_def) { + return errors::InvalidArgument("Unable to find FunctionDef for ", + function_name, " in registry."); + } + for (const NodeDef& node_def : function_def->node_def()) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def)); + // TODO(b/65524810): Hack to allow functions to capture Dataset op + // nodes needed for FlatMap. Currently, source datasets nodes have been + // marked stateful to avoid constant folding since we do not have a + // good way of serializing them. + if (IsOpWhitelisted(op_def)) { + continue; + } + if (op_def->is_stateful()) { + return errors::InvalidArgument( + "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ", + "in function ", function_name, " is stateful. ", + "Saving stateful functions is not supported yet."); + } + } + return Status::OK(); + } + + // Returns whether an op has been whitelisted for use inside map_fns. + // Uses a heuristic to whitelist source dataset ops which have been + // marked stateful due to b/65524810. + // Also looks up the `op_def->name` in the global + // `WhitelistedStatefulOpRegistry`. + bool IsOpWhitelisted(const OpDef* op_def) const { + return (StringPiece(op_def->name()).ends_with("Dataset") && + op_def->output_arg_size() == 1 && + op_def->output_arg(0).type() == DT_VARIANT) || + dataset::WhitelistedStatefulOpRegistry::Global()->Contains( + op_def->name()); + } + + bool HasAttr(const string& op_type_name, const string& attr_name) const; + + bool HasAttr(const OpDef* op_def, const string& attr_name) const { + for (auto attr : op_def->attr()) { + if (attr.name() == attr_name) { + return true; + } + } + return false; + } + + Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) { + if (attr_value.has_func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); + } else if (attr_value.has_list()) { + for (const NameAttrList& name_attr_list : attr_value.list().func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); + } + } + return Status::OK(); + } + + GraphDefBuilder* b_; }; -} // namespace dataset - -// Use this macro to whitelist an op that is marked stateful but needs to be -// used inside a map_fn in an input pipeline. This is only needed if you wish -// to be able to checkpoint the state of the input pipeline. We currently -// do not allow stateful ops to be defined inside of map_fns since it is not -// possible to save their state. -// Note that the state of the whitelisted ops inside functions will not be -// saved during checkpointing, hence this should only be used if the op is -// marked stateful for reasons like to avoid constant folding during graph -// optimiztion but is not stateful. -// If possible, try to remove the stateful flag on the op first. -// Example usage: +class StatsAggregator; + +// A cut-down version of OpKernelContext for running computations in +// iterators. Note that we cannot simply use OpKernelContext here +// because we might run computation in an iterator whose lifetime is +// not nested within the lifetime of a single OpKernelContext +// (e.g. asynchronous prefetching). // -// WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader"); +// TODO(mrry): We will probably need to support more of +// OpKernelContext here. For example, should allocation be handled by +// the IteratorContext? +// TODO(mrry): We're making some daring assumptions about the lifetime +// of the runner passed in here. A runner will be deleted when the original +// step ends, but all existing runners only close over session-lifetime (or +// longer-lived) state, so we can make a copy of the function. There's nothing +// in the definition of the API from which we took the runner to guarantee that +// what we are doing is safe. We should formalize the properties here. +class IteratorContext { + public: + struct Params { + // Interface to operating system functionality. + Env* env; + + // Function call support. + std::function)> runner = nullptr; + + // A function that returns the current `StatsAggregator` instance to be + // used when recording statistics about the iterator. + // + // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator` + // is a property of the `IteratorResource` (which this class does not know + // about), and (ii) it can change after the `IteratorContext` has been + // created. Better suggestions are welcome! + std::function()> stats_aggregator_getter = + nullptr; + + // The FunctionLibraryRuntime object to be used to make function calls. + FunctionLibraryRuntime* lib = nullptr; + std::shared_ptr function_library = nullptr; + + // The Allocator to be used to allocate the output of an iterator. + Allocator* allocator = nullptr; + }; + + explicit IteratorContext(Params params) : params_(std::move(params)) {} + + Env* env() const { return params_.env; } + + std::function)>* runner() { + return ¶ms_.runner; + } + + std::shared_ptr stats_aggregator() { + if (params_.stats_aggregator_getter) { + return params_.stats_aggregator_getter(); + } else { + return nullptr; + } + } + + std::shared_ptr function_library() { + return params_.function_library; + } + + FunctionLibraryRuntime* lib() { return params_.lib; } + + void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; } + + Allocator* allocator(AllocatorAttributes attrs); + + private: + Params params_; +}; + +// Represents the current position in a range of outputs, where the +// range of outputs is typically represented by an `DatasetBase`, +// defined below. +class IteratorBase { + public: + virtual ~IteratorBase() {} + + // Gets the next output from the range that this iterator is traversing. + // + // If at least one output remains in this iterator's range, that + // output will be stored in `*out_tensors` and `false` will be + // stored in `*end_of_sequence`. + // + // If no more outputs remain in this iterator's range, `true` will + // be stored in `*end_of_sequence`, and the content of + // `*out_tensors` will be undefined. + // + // This method is thread-safe. + // + // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and + // potentially remove this method. + virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) = 0; + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this + // iterator. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this iterator. + virtual const std::vector& output_shapes() const = 0; + + // Saves the state of this iterator. + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { + return SaveInternal(writer); + } + + // Restores the state of this iterator. + virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) { + return RestoreInternal(ctx, reader); + } + + protected: + // This is needed so that sub-classes of IteratorBase can call + // `SaveInternal` on their parent iterators, e.g., in + // `RepeatDataasetOp::Dataset`. + Status SaveParent(IteratorStateWriter* writer, + const std::unique_ptr& parent) { + return parent->SaveInternal(writer); + } + + // This is needed so that sub-classes of IteratorBase can call + // `RestoreInternal` on their parent iterators, e.g., in + // `RepeatDataasetOp::Dataset`. + Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader, + const std::unique_ptr& parent) { + return parent->RestoreInternal(ctx, reader); + } + + // Saves the state of this iterator recursively. + virtual Status SaveInternal(IteratorStateWriter* writer) { + return errors::Unimplemented("SaveInternal"); + } + + // Restores the state of this iterator recursively. + virtual Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) { + return errors::Unimplemented("RestoreInternal"); + } +}; + +// Represents a (potentially infinite) range of outputs, where each +// output is a tuple of tensors. +class DatasetBase : public core::RefCounted { + public: + // Returns a new iterator for iterating over the range of elements in + // this dataset. + // + // This method may be called multiple times on the same instance, + // and the resulting iterators will have distinct state. Each + // iterator will traverse all elements in this dataset from the + // start. + // + // Ownership of the created iterator will be transferred to the caller. + // + // The prefix identifies the sequence of iterators leading up to the newly + // created iterator. + virtual std::unique_ptr MakeIterator( + const string& prefix) const = 0; + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this + // dataset. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this dataset. + virtual const std::vector& output_shapes() const = 0; + + // A human-readable debug string for this dataset. + virtual string DebugString() = 0; + + // Serializes the dataset and writes it to the `writer`. + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { + return errors::Unimplemented("DatasetBase::Save"); + } + + protected: + // TODO(srbs): Ideally all graph related logic should reside in + // GraphDatasetBase. However, that would require Datasets defined in all ops + // to derive from GraphDatasetBase. Once that is done we can move + // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase. + class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { + public: + DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} + Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset, + Node** output) { + return dataset->AsGraphDefInternal(ctx, this, output); + } + }; + + virtual Status AsGraphDefInternal(OpKernelContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const { + return AsGraphDefInternal(b, node); + } + + virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** node) const { + return errors::Unimplemented("AsGraphDefInternal"); + } +}; + +// Base-class for datasets that are built by ops. +class GraphDatasetBase : public DatasetBase { + public: + GraphDatasetBase(OpKernelContext* ctx) + : op_name_(ctx->op_kernel().type_string()) {} + + const string op_name() const { return op_name_; } + + Status Save(OpKernelContext* ctx, + IteratorStateWriter* writer) const override { + string serialized_graph_def; + string output_node; + TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); + return Status::OK(); + } + + // Key for storing the Dataset graph in the serialized format. + static const char kDatasetGraphKey[]; + + // Key for storing the output node of the Dataset graph in the serialized + // format. + static const char kDatasetGraphOutputNodeKey[]; + + private: + Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, + string* output_node) const; + + const string op_name_; +}; + +// Represents an iterator that is associated with a particular parent dataset. +template +class DatasetIterator : public IteratorBase { + public: + struct Params { + // Owns one reference on the shared dataset resource. + const DatasetType* dataset; + + // Identifies the sequence of iterators leading up to this iterator. + const string prefix; + }; + + explicit DatasetIterator(const Params& params) : params_(params) { + params_.dataset->Ref(); + } + + ~DatasetIterator() override { params_.dataset->Unref(); } + + // The dataset from which this iterator was created. + const DatasetType* dataset() const { return params_.dataset; } + + // The sequence of iterators leading up to this iterator. + const string prefix() const { return params_.prefix; } + + const DataTypeVector& output_dtypes() const override { + return params_.dataset->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return params_.dataset->output_shapes(); + } + + Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) final { + port::Tracing::TraceMe activity(params_.prefix); + Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); + if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) { + s = errors::Internal( + "Iterator \"", params_.prefix, + "\" returned OutOfRange without setting `*end_of_sequence`. This " + "indicates that an error may have occurred. Original message: ", + s.error_message()); + LOG(ERROR) << s; + } + return s; + } + + Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final { + TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer)); + return IteratorBase::Save(ctx, writer); + } + + protected: + // Internal implementation of GetNext that is wrapped in tracing logic. + virtual Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) = 0; + + string full_name(const string& name) const { + return strings::StrCat(prefix(), ":", name); + } + + private: + Params params_; +}; + +// Encapsulates the work required to plug a DatasetBase into the core TensorFlow +// graph execution engine. +class DatasetOpKernel : public OpKernel { + public: + DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) final; + + protected: + // Subclasses should implement this method. It will be called during Compute + // execution. + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; + + template + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); + } +}; + +// Encapsulates the work required to plug unary Datasets into the core +// TensorFlow graph execution engine. +class UnaryDatasetOpKernel : public DatasetOpKernel { + public: + UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) = 0; +}; + +// Encapsulates the work required to plug binary Datasets into the core +// TensorFlow graph execution engine. +class BinaryDatasetOpKernel : public DatasetOpKernel { + public: + BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase* another_input, + DatasetBase** output) = 0; +}; + +// Validates and extracts a `DatasetBase` object from `tensor`. +// +// `tensor` must have been written by a call to SetVariantTensorToDataset(). +// +// The retrieved pointer is a borrowed reference to the dataset, which is owned +// by the tensor. The consumer must either acquire its own reference to the +// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not +// destroyed or mutated while the retrieved pointer is in use. +Status GetDatasetFromVariantTensor(const Tensor& tensor, + DatasetBase** out_dataset); + +// Stores a `DatasetBase` object in `tensor`. // -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \ - WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ - WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ - static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ - ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ - name) +// The ownership of `dataset` is transferred to `tensor`. +Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); } // namespace tensorflow -#endif // TENSORFLOW_FRAMEWORK_DATASET_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h new file mode 100644 index 0000000000000000000000000000000000000000..3b48999edb37da4fdf232f2cbcd61df7affb40f2 --- /dev/null +++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace dataset { +// Registry for stateful ops that need to be used in dataset functions. +// See below macro for usage details. +class WhitelistedStatefulOpRegistry { + public: + Status Add(StringPiece op_name) { + op_names_.insert(op_name); + return Status::OK(); + } + + bool Contains(StringPiece op_name) { + return op_names_.find(op_name) != op_names_.end(); + } + + static WhitelistedStatefulOpRegistry* Global() { + static WhitelistedStatefulOpRegistry* reg = + new WhitelistedStatefulOpRegistry; + return reg; + } + + private: + WhitelistedStatefulOpRegistry() {} + WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); + WhitelistedStatefulOpRegistry operator=( + WhitelistedStatefulOpRegistry const& copy); + std::set op_names_; +}; + +} // namespace dataset + +// Use this macro to whitelist an op that is marked stateful but needs to be +// used inside a map_fn in an input pipeline. This is only needed if you wish +// to be able to checkpoint the state of the input pipeline. We currently +// do not allow stateful ops to be defined inside of map_fns since it is not +// possible to save their state. +// Note that the state of the whitelisted ops inside functions will not be +// saved during checkpointing, hence this should only be used if the op is +// marked stateful for reasons like to avoid constant folding during graph +// optimiztion but is not stateful. +// If possible, try to remove the stateful flag on the op first. +// Example usage: +// +// WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader"); +// +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \ + WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ + WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ + static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ + name) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index d6b576166cafb4a70a24a2a96db12679f5ea644d..eae8e6c3c10c4b49081aed0e253d9a6f382f562b 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1064,26 +1064,36 @@ Status FunctionLibraryDefinition::AddLibrary( return Status::OK(); } -void FunctionLibraryDefinition::RemoveFunction(const string& func) { +Status FunctionLibraryDefinition::RemoveFunction(const string& func) { const auto& i = function_defs_.find(func); - DCHECK(i != function_defs_.end()); + if (i == function_defs_.end()) { + return errors::InvalidArgument("Tried to remove non-existent function ", + func); + } function_defs_.erase(i); + return Status::OK(); } -void FunctionLibraryDefinition::RemoveGradient(const string& func) { +Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); - DCHECK(i != func_grad_.end()); + if (i == func_grad_.end()) { + return errors::InvalidArgument("Tried to remove non-existent gradient ", + func); + } func_grad_.erase(i); + return Status::OK(); } void FunctionLibraryDefinition::Remove( const std::vector& funcs, const std::vector& funcs_with_grads) { for (const string& f : funcs) { - RemoveFunction(f); + Status s = RemoveFunction(f); + DCHECK(s.ok()); } for (const string& f : funcs_with_grads) { - RemoveGradient(f); + Status s = RemoveGradient(f); + DCHECK(s.ok()); } } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index b933ee0b0e4009b1568d7465ca28d4b4f9a018e6..e27001133bbb5056abf1a3e1f5b9d69c8e01bc56 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -35,6 +35,7 @@ namespace tensorflow { class CancellationManager; class GraphDef; class OpKernel; +class ProcessFunctionLibraryRuntime; class ResourceMgr; class Rendezvous; class ScopedStepContainer; @@ -312,6 +313,14 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // This operation is atomic. Status AddGradientDef(const GradientDef& grad); + // Remove function `func` from the library. Returns non-OK Status unless + // `func` is in the library. + Status RemoveFunction(const string& func); + + // Remove gradient of function `func` from the library. Returns non-OK Status + // unless `func` has a gradient. + Status RemoveGradient(const string& func); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. @@ -384,13 +393,6 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // attr from. const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; - // Remove function `func` from the library. `func` must be in the library. - void RemoveFunction(const string& func); - - // Remove gradient of function `func` from the library. `func` must have - // a gradient. - void RemoveGradient(const string& func); - // Remove all functions in `funcs` and all gradients of // functions in `funcs_with_grads` from this library. void Remove(const std::vector& funcs, @@ -534,6 +536,10 @@ class FunctionLibraryRuntime { virtual int graph_def_version() = 0; typedef uint64 LocalHandle; + + virtual Status Clone(std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) = 0; }; // Returns a canonicalized string for the instantiation of the diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index 99a5d0a054e9fe2c5dd729e165276369ebea7a71..4c38fbbe591a5d07ba4cbbea00dcbfb41ca2f403 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ #include - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" // Disable clang-format to prevent 'FixedPoint' header from being included // before 'Tensor' header on which it depends. @@ -43,12 +42,47 @@ typedef Eigen::QUInt16 quint16; } // namespace tensorflow + + + +static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) { +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return *reinterpret_cast( + reinterpret_cast(&float_val)); +#else + return *reinterpret_cast( + &(reinterpret_cast(&float_val)[1])); +#endif +} + namespace Eigen { -// TOOD(xpan): We probably need to overwrite more methods to have correct eigen -// behavior. E.g. loest(), is_integer, etc. See NumTraits.h in eigen. +// TODO(xpan): We probably need to overwrite more methods to have correct eigen +// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen. template <> struct NumTraits - : GenericNumTraits {}; + : GenericNumTraits { + enum { + IsInteger = 0, + IsSigned = 1, + RequireInitialization = 0 + }; + static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() { + return FloatToBFloat16(NumTraits::highest()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() { + return FloatToBFloat16(NumTraits::lowest()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() { + return FloatToBFloat16(NumTraits::infinity()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() { + return FloatToBFloat16(NumTraits::quiet_NaN()); + } +}; + using ::tensorflow::operator==; using ::tensorflow::operator!=; diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index fd2d06be9899852fa8ed61b2fdc4373ca4c0310e..56c013db9decbfc4688ce6764a5ae0066dfd4e7d 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -79,8 +79,14 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ +// TODO(mrry): Convert to std::make_unique when available. OpKernel::OpKernel(OpKernelConstruction* context) - : def_(new NodeDef(context->def())), + : OpKernel(context, + std::unique_ptr(new NodeDef(context->def()))) {} + +OpKernel::OpKernel(OpKernelConstruction* context, + std::unique_ptr node_def) + : def_(std::move(node_def)), input_types_(context->input_types().begin(), context->input_types().end()), input_memory_types_(context->input_memory_types().begin(), diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index b72f1405cffd83439dd837fa7f8e641ecf44e2ae..c45026c6af3df14fb42b029a1a72283ce1c814cb 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -75,6 +75,14 @@ class OpKernel { // OpKernel won't be instantiated by the scheduler, so you may perform // expensive initialization in the descendant's constructor. explicit OpKernel(OpKernelConstruction* context); + + // Specialized constructor that enables the descendant to provide a different + // `NodeDef` value. For example, this constructor can be used to provide a + // stripped-down `NodeDef` that does not contain the full set of attrs (such + // as tensor values) if the descendant stores them in a different form. + explicit OpKernel(OpKernelConstruction* context, + std::unique_ptr node_def); + virtual ~OpKernel(); // An OpKernel's computation can be either synchronous or @@ -901,9 +909,13 @@ class OpKernelContext { } AllocatorAttributes input_alloc_attr(int index) const { - DCHECK_GE(index, 0); - DCHECK_LT(index, params_->input_alloc_attrs->size()); - return (*params_->input_alloc_attrs)[index]; + if (params_->input_alloc_attrs == nullptr) { + return AllocatorAttributes(); + } else { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_->input_alloc_attrs->size()); + return (*params_->input_alloc_attrs)[index]; + } } AllocatorAttributes output_alloc_attr(int index) const { diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index 17d16c9b8d6871794dc0d048e0fe230b4e6ad1e6..e90596980f840588768c7883031f1ad179628833 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -179,7 +179,7 @@ limitations under the License. // Call "m" on all types. #define TF_CALL_ALL_TYPES(m) \ - TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) + TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) TF_CALL_variant(m) // Call "m" on POD and string types. #define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m) diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 77a3edcc10e9c5ceb8bf26570c3e271f9e853444..0645ec42822fe7633e0517b28e50b0c221b3f80e 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -886,8 +886,9 @@ bool Tensor::CanUseDMA() const { namespace { // Print from left dim to right dim recursively. template -void PrintOneDim(int dim_index, gtl::InlinedVector shape, int64 limit, - int shape_size, T* data, int64* data_index, string* result) { +void PrintOneDim(int dim_index, const gtl::InlinedVector& shape, + int64 limit, int shape_size, const T* data, int64* data_index, + string* result) { if (*data_index >= limit) return; int64 element_count = shape[dim_index]; // We have reached the right-most dimension of the tensor. diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc index 4118f14f8bf4084e860e7986f3ce62aaefedf366..4f3a6ec38cb88213c7127df41823bc16e9834d09 100644 --- a/tensorflow/core/graph/costmodel.cc +++ b/tensorflow/core/graph/costmodel.cc @@ -158,8 +158,8 @@ void CostModel::SetNumOutputs(const Node* node, int num_outputs) { Ensure(id, 0); auto perslot = &slot_bytes_[id]; if (!perslot->empty()) { - CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node=" - << node->name(); + CHECK_EQ(num_outputs, perslot->size()) + << "Cannot resize slot_bytes, node=" << node->name(); } Ensure(id, num_outputs); } @@ -252,9 +252,12 @@ void CostModel::RecordMaxMemorySize(const Node* node, int output_slot, const DataType& dtype) { const int id = Id(node); if (id < 0) return; - CHECK_LT(output_slot, node->num_outputs()) - << "Unexpected output slot for node " << node->DebugString() << ". Got " - << output_slot << " but its num_outputs is " << node->num_outputs(); + if (output_slot >= node->num_outputs()) { + LOG(ERROR) << "Unexpected output slot for node " << node->DebugString() + << ". Got " << output_slot << " but its num_outputs is " + << node->num_outputs(); + return; + } Ensure(id, node->num_outputs()); auto& current_max = max_mem_usage_[id].output_port_mem[output_slot]; // If the memory allocator doesn't track memory usage, let's infer a lower diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h index c60a946c2cc08af439134cd11b33c3cf97fe6d2f..9b703e46938b3355ed769045cdb3f298b48bb922 100644 --- a/tensorflow/core/graph/costmodel.h +++ b/tensorflow/core/graph/costmodel.h @@ -198,7 +198,7 @@ class CostModel { // Cumulative execution time. std::vector time_; // Cumulative Bytes output on each channel. - std::vector > slot_bytes_; + std::vector> slot_bytes_; // Maximum execution time std::vector max_exec_time_; @@ -217,7 +217,7 @@ class CostModel { }; std::vector max_mem_usage_; - std::vector > output_port_alloc_ids_; + std::vector> output_port_alloc_ids_; std::set persistent_alloc_ids_; std::map> persistent_alloc_ids_by_devices_; diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index b620127d9072a845721f97112f4bad107412b06f..93d8dd6f1100e9474b6e1c7afc56699163fc713f 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -62,8 +62,8 @@ class Node; class VersionDef; class WhileContext; -class NeighborIter; // Declared below -class NodeIter; // Declared below +class NeighborIter; // Declared below +class NodeIter; // Declared below class NodeProperties; // Defined in .cc class Node { diff --git a/tensorflow/core/graph/graph_def_builder_test.cc b/tensorflow/core/graph/graph_def_builder_test.cc index e85de71ef79988199cd194274f2ef9986e86d350..e928c81b45385ca0e10c2e6e1521d6d5d5a5eaf9 100644 --- a/tensorflow/core/graph/graph_def_builder_test.cc +++ b/tensorflow/core/graph/graph_def_builder_test.cc @@ -26,7 +26,6 @@ namespace tensorflow { namespace { TEST(GraphDefBuilderTest, Version) { - // Verify that our assertions will be nontrivial ASSERT_LT(0, TF_GRAPH_DEF_VERSION); diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index 3df981437afed760744ef870fd542d7abdd6e25d..1b99d54e8e33fd5155913a78ee833343bf92b905 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -21,102 +21,101 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { - // Since our ops are going to produce and also consume N addition tensors - // (Mkl) for N Tensorflow tensors, we can have following different - // orderings among these 2N tensors. - // - // E.g., for Tensorflow tensors A, B, and C, our ops will produce and - // consume A_m, B_m, and C_m additionally. - // - // INTERLEAVED: in this case 2N tensors are interleaved. So for above - // example, the ordering looks like: A, A_m, B, B_m, C, C_m. - // - // CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed - // by N Mkl tensors. So for above example, the ordering looks - // like: A, B, C, A_m, B_m, C_m - // - // Following APIs map index of original Tensorflow tensors to their - // appropriate position based on selected ordering. For contiguous ordering, - // we need to know the total number of tensors (parameter total). - // - typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; - // NOTE: Currently, we use contiguous ordering. If you change this, then you - // would need to change Mkl op definitions in nn_ops.cc. - static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; +// Since our ops are going to produce and also consume N addition tensors +// (Mkl) for N Tensorflow tensors, we can have following different +// orderings among these 2N tensors. +// +// E.g., for Tensorflow tensors A, B, and C, our ops will produce and +// consume A_m, B_m, and C_m additionally. +// +// INTERLEAVED: in this case 2N tensors are interleaved. So for above +// example, the ordering looks like: A, A_m, B, B_m, C, C_m. +// +// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed +// by N Mkl tensors. So for above example, the ordering looks +// like: A, B, C, A_m, B_m, C_m +// +// Following APIs map index of original Tensorflow tensors to their +// appropriate position based on selected ordering. For contiguous ordering, +// we need to know the total number of tensors (parameter total). +// +typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; +// NOTE: Currently, we use contiguous ordering. If you change this, then you +// would need to change Mkl op definitions in nn_ops.cc. +static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; - // Get index of MetaData tensor from index 'n' of Data tensor. - inline int DataIndexToMetaDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - // For interleaved ordering, Mkl tensor follows immediately after - // Tensorflow tensor. - return n + 1; - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. - return n + total_tensors / 2; - } +// Get index of MetaData tensor from index 'n' of Data tensor. +inline int DataIndexToMetaDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + // For interleaved ordering, Mkl tensor follows immediately after + // Tensorflow tensor. + return n + 1; + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. + return n + total_tensors / 2; } +} - int inline GetTensorDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - return 2 * n; // index corresponding to nth input/output tensor - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - return n; - } - } +int inline GetTensorDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + return 2 * n; // index corresponding to nth input/output tensor + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + return n; + } +} - int inline GetTensorMetaDataIndex(int n, int total_tensors) { - // Get index for TensorData first and then use mapping function - // to get TensorMetaData index from TensorData index. - int tidx = GetTensorDataIndex(n, total_tensors); - return DataIndexToMetaDataIndex(tidx, total_tensors); - } +int inline GetTensorMetaDataIndex(int n, int total_tensors) { + // Get index for TensorData first and then use mapping function + // to get TensorMetaData index from TensorData index. + int tidx = GetTensorDataIndex(n, total_tensors); + return DataIndexToMetaDataIndex(tidx, total_tensors); +} namespace mkl_op_registry { - static const char* kMklOpLabel = "MklOp"; - static const char* kMklOpLabelPattern = "label='MklOp'"; - // Prefix that we add to Tensorflow op name to construct Mkl op name. - static const char* const kMklOpPrefix = "_Mkl"; +static const char* kMklOpLabel = "MklOp"; +static const char* kMklOpLabelPattern = "label='MklOp'"; +// Prefix that we add to Tensorflow op name to construct Mkl op name. +static const char* const kMklOpPrefix = "_Mkl"; - // Get the name of Mkl op from original TensorFlow op - // We prefix 'Mkl' to the original op to get Mkl op. - inline string GetMklOpName(const string& name) { - return string(kMklOpPrefix) + name; - } +// Get the name of Mkl op from original TensorFlow op +// We prefix 'Mkl' to the original op to get Mkl op. +inline string GetMklOpName(const string& name) { + return string(kMklOpPrefix) + name; +} - // Check whether opname with type T is registered as MKL-compliant. - // - // @input: name of the op - // @input: T datatype to be used for checking op - // @return: true if opname is registered as Mkl op; false otherwise - static inline bool IsMklOp(const std::string& op_name, DataType T) { - string kernel = KernelsRegisteredForOp(op_name); - bool result = - kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); - return result; - } +// Check whether opname with type T is registered as MKL-compliant. +// +// @input: name of the op +// @input: T datatype to be used for checking op +// @return: true if opname is registered as Mkl op; false otherwise +static inline bool IsMklOp(const std::string& op_name, DataType T) { + string kernel = KernelsRegisteredForOp(op_name); + bool result = + kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); + return result; +} - // Check whether opname with type T is registered as MKL-compliant and - // is element-wise. - // - // @input: name of the op - // @input: T datatype to be used for checking op - // @return: true if opname is registered as element-wise Mkl op; - // false otherwise - static inline bool IsMklElementWiseOp(const std::string& op_name, - DataType T) { - if (!IsMklOp(op_name, T)) { - return false; - } - bool result = (0 == op_name.compare(GetMklOpName("Add")) || - 0 == op_name.compare(GetMklOpName("Sub")) || - 0 == op_name.compare(GetMklOpName("Mul")) || - 0 == op_name.compare(GetMklOpName("Maximum")) || - 0 == op_name.compare(GetMklOpName("SquaredDifference"))); - - return result; +// Check whether opname with type T is registered as MKL-compliant and +// is element-wise. +// +// @input: name of the op +// @input: T datatype to be used for checking op +// @return: true if opname is registered as element-wise Mkl op; +// false otherwise +static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) { + if (!IsMklOp(op_name, T)) { + return false; } + bool result = (0 == op_name.compare(GetMklOpName("Add")) || + 0 == op_name.compare(GetMklOpName("Sub")) || + 0 == op_name.compare(GetMklOpName("Mul")) || + 0 == op_name.compare(GetMklOpName("Maximum")) || + 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + + return result; +} } // namespace mkl_op_registry } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 911d931a52cb48bcd87b2c39134255c03a055ef8..0e8a1cb26ce76855f334e16c8fa46c677735a34f 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -37,8 +37,8 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" -#include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/graph/mkl_layout_pass.h" namespace tensorflow { @@ -281,7 +281,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = - "_MklConv2DWithBiasBackpropBias"; + "_MklConv2DWithBiasBackpropBias"; csinfo_.relu = "Relu"; csinfo_.relu_grad = "ReluGrad"; csinfo_.reshape = "Reshape"; @@ -297,10 +297,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // End - element-wise ops. See note above. // NOTE: names are alphabetically sorted. - rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, - AddNRewrite, nullptr}); - rinfo_.push_back({csinfo_.add, - mkl_op_registry::GetMklOpName(csinfo_.add), + rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), + CopyAttrsAddN, AddNRewrite, nullptr}); + rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), @@ -337,14 +336,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.fused_batch_norm, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.fused_batch_norm_grad, - mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), - CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); + rinfo_.push_back( + {csinfo_.fused_batch_norm_grad, + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), + CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsIdentity, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.lrn, - mkl_op_registry::GetMklOpName(csinfo_.lrn), + rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), CopyAttrsLRN, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), @@ -358,11 +357,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.mul, - mkl_op_registry::GetMklOpName(csinfo_.mul), + rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.relu, - mkl_op_registry::GetMklOpName(csinfo_.relu), + rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.relu_grad, mkl_op_registry::GetMklOpName(csinfo_.relu_grad), @@ -373,8 +370,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.squared_difference, mkl_op_registry::GetMklOpName(csinfo_.squared_difference), CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.sub, - mkl_op_registry::GetMklOpName(csinfo_.sub), + rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), CopyAttrsDataType, AlwaysRewrite, nullptr}); // Add info about which ops to add workspace edge to and the slots. @@ -388,9 +384,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, IsBiasAddGradInMatMulContext}; - biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad, - csinfo_.mkl_conv2d_with_bias, - IsBiasAddGradInConv2DWithBiasContext}; + biasaddgrad_conv2dwithbias_context_ = { + csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, + IsBiasAddGradInConv2DWithBiasContext}; cinfo_.push_back(&biasaddgrad_matmul_context_); cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); @@ -410,9 +406,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Structure to specify the context information used in a node rewrite rule typedef struct { - string node; // Name of the node to be rewritten - string fwd; // Name of the node in the forward pass that this node - // corresponds to + string node; // Name of the node to be rewritten + string fwd; // Name of the node in the forward pass that this node + // corresponds to std::function context_match_fn; } ContextInfo; @@ -615,14 +611,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::vector ksize, strides; CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); - CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), - true); + CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); CHECK_EQ(FormatFromString(data_format_str, &data_format), true); // Condition that specifies non-batch-wise and non-depth-wise pooling. - if (GetTensorDim(ksize, data_format, 'N') == 1 && + if (GetTensorDim(ksize, data_format, 'N') == 1 && GetTensorDim(strides, data_format, 'N') == 1 && - GetTensorDim(ksize, data_format, 'C') == 1 && + GetTensorDim(ksize, data_format, 'C') == 1 && GetTensorDim(strides, data_format, 'C') == 1) { return true; } @@ -785,8 +780,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { for (const Edge* fe : first_inp_of_filter->out_edges()) { if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && fe->dst_input() == 0) { - VLOG(1) << "MklLayoutRewritePass: found " - << fe->dst()->DebugString() + VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString() << " as the forward node for matching context, backward" << " node is: " << n->DebugString(); *fwd_node = fe->dst(); @@ -803,13 +797,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // // @return - true (if BiasAddGrad is associated with MatMul); // false otherwise. - static bool IsBiasAddGradInMatMulContext(const Node* n, - const Node** fwd_node, + static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node, void* ci) { return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); } - // Rewrite rule that uses context-information for matching, // used in scenario 2. // @@ -880,10 +872,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output output_nodes - the list of new nodes creating Mkl tensors // // @return None - void GetNodesProducingMklTensorList(std::unique_ptr* g, - Node* orig_node, const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes); + void GetNodesProducingMklTensorList( + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes); // Get a node that will feed an Mkl tensor to the new // node that we are constructing. The output node could be (1) 'n' @@ -900,7 +893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // will feed the tensor // @return None void GetNodeProducingMklTensor(std::unique_ptr* g, Node* orig_node, - Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); + Node* n, int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are @@ -970,9 +964,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; MklLayoutRewritePass::ContextInfo - MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; + MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; MklLayoutRewritePass::ContextInfo - MklLayoutRewritePass::biasaddgrad_matmul_context_; + MklLayoutRewritePass::biasaddgrad_matmul_context_; std::vector MklLayoutRewritePass::cinfo_; // We register Mkl rewrite pass for phase 1 in post partitioning group. @@ -1041,13 +1035,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -1060,8 +1054,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); } @@ -1069,11 +1063,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, } void MklLayoutRewritePass::GetNodesProducingMklTensorList( - std::unique_ptr* g, - Node* orig_node, - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes) { + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, int* input_idx, + int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -1090,8 +1082,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( int mkl_node_output_slot = 0; GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, - mkl_node_output_slot)); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); (*input_idx)++; list_length--; } @@ -1101,9 +1093,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // node that we are constructing. An input node could be (1) 'n' // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, - Node* orig_node, Node* n, - int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { +void MklLayoutRewritePass::GetNodeProducingMklTensor( + std::unique_ptr* g, Node* orig_node, Node* n, int n_output_slot, + Node** mkl_node, int* mkl_node_output_slot) { CHECK_NOTNULL(n); CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node_output_slot); @@ -1234,8 +1226,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs( if (ArgIsList(arg)) { std::vector new_node_inputs; int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, - N, &new_node_inputs); + GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, + &new_node_inputs); nb->Input(new_node_inputs); nn_slot_idx++; } else { @@ -1336,13 +1328,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( TensorShape dummy_shape({1}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // same the device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -1355,8 +1347,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); } @@ -1374,7 +1366,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); for (auto ws : wsinfo_) { if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a fwd op, then we need to check if there is an // edge from this node's fwd_slot to bwdop's bwd_slot. If there is // an edge, then we just add an attribute on this node for setting @@ -1400,8 +1393,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), - T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), + T)) { // If this op is a bwd op, then we need to add workspace edge and // it's Mkl tensor edge between its corresponding fwd op and this // op. Corresponding fwd op is specified in 'fwd_op' field of @@ -1416,7 +1410,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( if (e->src_output() == ws.fwd_slot && // We would have rewritten the forward op, so we need to use // GetMklOpName call to get its Mkl name. - e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(ws.fwd_op) && e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); CHECK_NOTNULL(ws_tensors); @@ -1593,7 +1588,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, } void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { + NodeBuilder* nb) { DataType T; DataType Tshape; @@ -1869,8 +1864,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* succ, if (e->IsControlEdge()) { CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(), - e->dst_input())); + CHECK_NOTNULL( + (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input())); } } @@ -1941,9 +1936,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, // and leave BiasAddGrad as it is. But we check for this condition // when we check for node rewrite rule. So we should not even come // here for MatMul. So we will fail now. - return Status( - error::Code::INVALID_ARGUMENT, - "No rewrite is required for BiasAddGrad for MatMul context."); + return Status( + error::Code::INVALID_ARGUMENT, + "No rewrite is required for BiasAddGrad for MatMul context."); } } @@ -2012,9 +2007,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, if (e->IsControlEdge()) { CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), - e->src()->num_outputs()), - e->dst(), e->dst_input())); + CHECK_NOTNULL((*g)->AddEdge( + new_node, + GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), + e->dst(), e->dst_input())); } } @@ -2070,7 +2066,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // BiasAddGrad is not an Mkl layer, so we make an exception for it. if (n->type_string() != csinfo_.bias_add_grad) { - if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) { + if (!mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(n->type_string()), T)) { return nullptr; } } @@ -2186,8 +2183,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr* g) { return MklLayoutRewritePass().RunPass(g); } -Status MklLayoutRewritePass::Run( - const GraphOptimizationPassOptions& options) { +Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } @@ -2421,7 +2417,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; csinfo_.conv2d_grad_filter_with_bias = - "__MklDummyConv2DBackpropFilterWithBias"; + "__MklDummyConv2DBackpropFilterWithBias"; csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; csinfo_.identity = "Identity"; @@ -2435,11 +2431,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_grad_filter_with_bias = - "_MklConv2DBackpropFilterWithBias"; + "_MklConv2DBackpropFilterWithBias"; csinfo_.relu = "Relu"; csinfo_.relu_grad = "ReluGrad"; - csinfo_.tanh = "Tanh"; - csinfo_.tanh_grad = "TanhGrad"; + csinfo_.tanh = "Tanh"; + csinfo_.tanh_grad = "TanhGrad"; csinfo_.reshape = "Reshape"; csinfo_.softmax = "Softmax"; csinfo_.split = "Split"; @@ -2456,9 +2452,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // NOTE: names are alphabetically sorted. rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, AddNRewrite}); - /* rinfo_.push_back({csinfo_.add, + rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), - CopyAttrsDataType, AlwaysRewrite}); */ + CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), CopyAttrsPooling, AlwaysRewrite}); @@ -2474,29 +2470,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConv2D, AlwaysRewrite}); - rinfo_.push_back({csinfo_.conv2d_with_bias, - csinfo_.mkl_conv2d_with_bias, + rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, - csinfo_.mkl_conv2d_grad_filter_with_bias, - CopyAttrsConv2D, AlwaysRewrite}); + csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D, + AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_input, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.fused_batch_norm, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite}); - rinfo_.push_back({csinfo_.fused_batch_norm_grad, - mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), - CopyAttrsFusedBatchNorm, AlwaysRewrite}); + rinfo_.push_back( + {csinfo_.fused_batch_norm_grad, + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), + CopyAttrsFusedBatchNorm, AlwaysRewrite}); rinfo_.push_back({csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsDataType, AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn, - mkl_op_registry::GetMklOpName(csinfo_.lrn), + rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), CopyAttrsLRN, AlwaysRewrite}); rinfo_.push_back({csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), @@ -2507,14 +2502,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.max_pool_grad, mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), CopyAttrsPooling, AlwaysRewrite}); - /* + rinfo_.push_back({csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), CopyAttrsDataType, AlwaysRewrite}); - */ rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsDataType, AlwaysRewrite}); @@ -2535,14 +2529,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), CopyAttrsDataType, AlwaysRewrite}); - /* + rinfo_.push_back({csinfo_.squared_difference, mkl_op_registry::GetMklOpName(csinfo_.squared_difference), CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), CopyAttrsDataType, AlwaysRewrite}); - */ + // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -2550,8 +2544,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Add a rule for merging nodes minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, - csinfo_.conv2d_with_bias, - GetConv2DOrBiasAdd}); + csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad, csinfo_.conv2d_grad_filter_with_bias, @@ -2846,9 +2839,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Default rewrite rule to be used in scenario 1 for rewrite. // @return - true (since we want to always rewrite) - static bool AlwaysRewrite(const Node* n) { - return true; - } + static bool AlwaysRewrite(const Node* n) { return true; } // Check if we are performing pooling on depth or batch. If it is, then we // do not rewrite MaxPool node to Mkl version. @@ -2862,14 +2853,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::vector ksize, strides; CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); - CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), - true); + CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); CHECK_EQ(FormatFromString(data_format_str, &data_format), true); // Condition that specifies non-batch-wise and non-depth-wise pooling. - if (GetTensorDim(ksize, data_format, 'N') == 1 && + if (GetTensorDim(ksize, data_format, 'N') == 1 && GetTensorDim(strides, data_format, 'N') == 1 && - GetTensorDim(ksize, data_format, 'C') == 1 && + GetTensorDim(ksize, data_format, 'C') == 1 && GetTensorDim(strides, data_format, 'C') == 1) { return true; } @@ -2941,10 +2931,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output output_nodes - the list of new nodes creating Mkl tensors // // @return None - void GetNodesProducingMklTensorList(std::unique_ptr* g, - Node* orig_node, const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes); + void GetNodesProducingMklTensorList( + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes); // Get a node that will feed an Mkl tensor to the new // node that we are constructing. The output node could be (1) 'n' @@ -2961,7 +2952,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // will feed the tensor // @return None void GetNodeProducingMklTensor(std::unique_ptr* g, Node* orig_node, - Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); + Node* n, int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are @@ -3096,13 +3088,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -3115,8 +3107,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); // Allow duplicate while adding control edge as it would fail (return // NULL) if we try to add duplicate edge. CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); @@ -3126,11 +3118,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, } void MklLayoutRewritePass::GetNodesProducingMklTensorList( - std::unique_ptr* g, - Node* orig_node, - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes) { + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, int* input_idx, + int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -3147,8 +3137,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( int mkl_node_output_slot = 0; GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, - mkl_node_output_slot)); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); (*input_idx)++; list_length--; } @@ -3158,9 +3148,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // node that we are constructing. An input node could be (1) 'n' // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, - Node* orig_node, Node* n, - int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { +void MklLayoutRewritePass::GetNodeProducingMklTensor( + std::unique_ptr* g, Node* orig_node, Node* n, int n_output_slot, + Node** mkl_node, int* mkl_node_output_slot) { CHECK_NOTNULL(n); CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node_output_slot); @@ -3292,8 +3282,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs( if (ArgIsList(arg)) { std::vector new_node_inputs; int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, - N, &new_node_inputs); + GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, + &new_node_inputs); nb->Input(new_node_inputs); nn_slot_idx++; } else { @@ -3394,13 +3384,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( TensorShape dummy_shape({1}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // same the device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -3413,8 +3403,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); // Allow duplicate while adding control edge as it would fail (return // NULL) if we try to add duplicate edge. CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); @@ -3434,8 +3424,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); for (auto ws : wsinfo_) { if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( - orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a fwd op, then we need to check if there is an // edge from this node's fwd_slot to bwdop's bwd_slot. If there is // an edge, then we just add an attribute on this node for setting @@ -3461,8 +3451,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( - orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), + T)) { // If this op is a bwd op, then we need to add workspace edge and // it's Mkl tensor edge between its corresponding fwd op and this // op. Corresponding fwd op is specified in 'fwd_op' field of @@ -3477,8 +3468,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( if (e->src_output() == ws.fwd_slot && // We would have rewritten the forward op, so we need to use // GetMklOpName call to get its Mkl name. - e->src()->type_string() == mkl_op_registry::GetMklOpName( - ws.fwd_op) && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(ws.fwd_op) && e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); CHECK_NOTNULL(ws_tensors); @@ -3645,7 +3636,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, } void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { + NodeBuilder* nb) { DataType T; DataType Tshape; @@ -3776,8 +3767,9 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, Node* m, Node* n) { CHECK_EQ(((m->type_string() == csinfo_.bias_add && n->type_string() == csinfo_.conv2d)) || - ((n->type_string() == csinfo_.bias_add && - m->type_string() == csinfo_.conv2d)), true); + ((n->type_string() == csinfo_.bias_add && + m->type_string() == csinfo_.conv2d)), + true); // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd, // BiasAdd is successor node, and Conv2D predecessor node. @@ -3796,8 +3788,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); - TF_CHECK_OK( - GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); + TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); // We check to ensure that data formats of both succ and pred are same. // We expect them to be same, so we can enforce this as assert. // But assert can be too strict, so we enforce this as a check. @@ -3900,8 +3891,8 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 // output (at slot 0). const int kConv2DWithBiasOutputSlot = 0; - CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, - e->dst(), e->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(), + e->dst_input())); } } @@ -3924,8 +3915,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( std::unique_ptr* g, Node* m, Node* n) { CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && n->type_string() == csinfo_.conv2d_grad_filter)) || - ((n->type_string() == csinfo_.bias_add_grad && - m->type_string() == csinfo_.conv2d_grad_filter)), true); + ((n->type_string() == csinfo_.bias_add_grad && + m->type_string() == csinfo_.conv2d_grad_filter)), + true); // If 'm' is BiasAddGrad, then 'n' is BackpropFilter. Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; @@ -4132,9 +4124,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, // NULL) if we try to add duplicate edge. CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), - e->src()->num_outputs()), - e->dst(), e->dst_input())); + CHECK_NOTNULL((*g)->AddEdge( + new_node, + GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), + e->dst(), e->dst_input())); } } @@ -4166,9 +4159,9 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // names. if (n->type_string() != csinfo_.conv2d_with_bias && n->type_string() != csinfo_.conv2d_grad_filter_with_bias && - !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( - n->type_string()), T)) { - return nullptr; + !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), + T)) { + return nullptr; } // For elementwise node, we reuse the Eigen implementation and pass the MKL @@ -4184,29 +4177,30 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // eigen code to reduce cross-library dependency. VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string(); if (mkl_op_registry::IsMklElementWiseOp( - mkl_op_registry::GetMklOpName(n->type_string()), T) || + mkl_op_registry::GetMklOpName(n->type_string()), T) || n->type_string().find("Identity") != string::npos) { VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string(); bool incoming_mkl_edge = false; int num_parent = 0; for (auto parent : n->in_edges()) { if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) { - VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is MKL op: " - << parent->src()->type_string(); + VLOG(1) << "ELEMENTWISE: parent " << num_parent++ + << " is MKL op: " << parent->src()->type_string(); incoming_mkl_edge = true; break; } else { - VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is NON-MKL op: " - << parent->src()->type_string(); + VLOG(1) << "ELEMENTWISE: parent " << num_parent++ + << " is NON-MKL op: " << parent->src()->type_string(); } } if (incoming_mkl_edge == false) { - VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which has no MKL " + VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which " + "has no MKL " "parents."; return nullptr; } else { - VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() << - " which has MKL parents"; + VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() + << " which has MKL parents"; } } @@ -4214,8 +4208,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // for this op, then we rewrite it to Mkl op. // Find matching RewriteInfo and then check that rewrite rule applies. for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { - if (n->type_string().compare(ri->name) == 0 && - ri->rewrite_rule(n)) { + if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { return &*ri; } } @@ -4297,8 +4290,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr* g) { return MklLayoutRewritePass().RunPass(g); } -Status MklLayoutRewritePass::Run( - const GraphOptimizationPassOptions& options) { +Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 7b8c3cccc5b81ad09b5a572de9fd2102f23f1f84..5e2a465e22c7cbe45cbea40ea7a11491e2b2ad24 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -125,8 +125,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); -REGISTER_OP("_MklInput2").Output("o: uint8") - .Output("o1: uint8").SetIsStateful(); +REGISTER_OP("_MklInput2") + .Output("o: uint8") + .Output("o1: uint8") + .SetIsStateful(); ///////////////////////////////////////////////////////////////////// // Unit tests related to node merge optiimization @@ -498,7 +500,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) { "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); } - // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { InitGraph( @@ -874,11 +875,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" - "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); + EXPECT_EQ( + DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" + "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } // Concat with 2 Mkl layers feeding it @@ -1273,7 +1275,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { "node { name: 'H' op: 'Input'}" "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['H', 'G'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), + EXPECT_EQ( + DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" @@ -1640,7 +1643,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['B', 'C'] }", kGPUDevice); + " input: ['B', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); } @@ -1666,7 +1670,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { "node { name: 'F' op: 'BiasAddGrad'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }", kGPUDevice); + " input: ['E'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" @@ -1687,7 +1692,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }", kGPUDevice); + " input: ['A', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" "A->D;A->E;B->D:1;C->D:2;D->E:1"); @@ -1700,7 +1706,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); } @@ -1713,7 +1720,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }", kGPUDevice); + " input: ['A', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); } @@ -1729,7 +1737,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -1745,7 +1754,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -1766,7 +1776,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" "B->D:1;B:1->D:2;C->E;D->E:1"); @@ -1788,7 +1799,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['B:0', 'B:1', 'A']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); @@ -1808,7 +1820,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { " attr { key: 'is_training' value { b: true } }" " input: ['A', 'B', 'C', 'D', 'E'] }" "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'F'] }", kGPUDevice); + " input: ['A', 'F'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);E(Input);" "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" @@ -1837,7 +1850,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { "node { name: 'Y' op: 'Input'}" "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}", kGPUDevice); + " input: ['E', 'Y']}", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" @@ -1972,8 +1986,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); -REGISTER_OP("_MklInput2").Output("o: uint8") - .Output("o1: uint8").SetIsStateful(); +REGISTER_OP("_MklInput2") + .Output("o: uint8") + .Output("o1: uint8") + .SetIsStateful(); ///////////////////////////////////////////////////////////////////// // Unit tests related to node merge optiimization @@ -2492,11 +2508,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" - "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); + EXPECT_EQ( + DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" + "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } // Concat with 2 Mkl layers feeding it @@ -2891,7 +2908,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { "node { name: 'H' op: 'Input'}" "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['H', 'G'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), + EXPECT_EQ( + DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" @@ -3258,7 +3276,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['B', 'C'] }", kGPUDevice); + " input: ['B', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); } @@ -3284,7 +3303,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { "node { name: 'F' op: 'BiasAddGrad'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }", kGPUDevice); + " input: ['E'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" @@ -3305,7 +3325,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }", kGPUDevice); + " input: ['A', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" "A->D;A->E;B->D:1;C->D:2;D->E:1"); @@ -3318,7 +3339,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); } @@ -3331,7 +3353,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }", kGPUDevice); + " input: ['A', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); } @@ -3347,7 +3370,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -3363,7 +3387,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -3384,7 +3409,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" "B->D:1;B:1->D:2;C->E;D->E:1"); @@ -3406,7 +3432,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['B:0', 'B:1', 'A']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); @@ -3426,7 +3453,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { " attr { key: 'is_training' value { b: true } }" " input: ['A', 'B', 'C', 'D', 'E'] }" "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'F'] }", kGPUDevice); + " input: ['A', 'F'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);E(Input);" "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" @@ -3455,7 +3483,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { "node { name: 'Y' op: 'Input'}" "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}", kGPUDevice); + " input: ['E', 'Y']}", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 599bb88f015bfc035b7666747571a652a954139d..5343e6802d1e75f516925d44ab680b96f4e157da 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -33,8 +33,8 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/graph/mkl_tfconversion_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/graph/mkl_tfconversion_pass.h" namespace tensorflow { @@ -152,12 +152,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( string data_format; TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); - bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) == - Status::OK(); + bool dst_dtype_found = + GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK(); // We compare source and destination datatypes only when both are found. if (dst_dtype_found && (src_datatype != dst_datatype)) { - string err_msg = "T attribute of " + src->name() + " and " + - dst->name() + " do not match. Will not insert" + + string err_msg = "T attribute of " + src->name() + " and " + dst->name() + + " do not match. Will not insert" + " MklToTf node in such case."; return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); } @@ -325,12 +325,12 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr* g) { // may not be Mkl node. DataType src_datatype; DataType dst_datatype; - bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) == - Status::OK() && - IsMklSupportedOp(src->type_string(), src_datatype)); - bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) == - Status::OK() && - IsMklSupportedOp(dst->type_string(), dst_datatype)); + bool src_is_mkl_op = + (GetNodeAttr(src->def(), "T", &src_datatype) == Status::OK() && + IsMklSupportedOp(src->type_string(), src_datatype)); + bool dst_is_mkl_op = + (GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK() && + IsMklSupportedOp(dst->type_string(), dst_datatype)); // Check if src with is Mkl-compliant, while dst is not Mkl-compliant. if (src_is_mkl_op && !dst_is_mkl_op) { diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index 592e4b789d0dcb7369e2f0c6db447eb9daa92870..aacd2ccb72df07ac6b31c9bd5b96deca499038e4 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/mem.h" namespace tensorflow { namespace grappler { @@ -48,6 +49,11 @@ DeviceProperties GetLocalCPUInfo() { device.set_l2_cache_size(Eigen::l2CacheSize()); device.set_l3_cache_size(Eigen::l3CacheSize()); + int64 free_mem = port::AvailableRam(); + if (free_mem < INT64_MAX) { + device.set_memory_size(free_mem); + } + (*device.mutable_environment())["cpu_instruction_set"] = Eigen::SimdInstructionSetsInUse(); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index cf317374cfa2bfe1d587e8e4d54a1234717abaa9..1af973855ea6beafa5807f19a2cbad7efa6e67ad 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -353,6 +353,9 @@ OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo( VLOG(1) << "Device: " << device.type() << " gflops: " << gflops << " gb_per_sec: " << gb_per_sec; + DCHECK_LT(0, gflops) << device.DebugString(); + DCHECK_LT(0, gb_per_sec) << device.DebugString(); + return {gflops, gb_per_sec}; } @@ -408,6 +411,12 @@ Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp( Costs OpLevelCostEstimator::PredictOpCountBasedCost( double operations, const OpInfo& op_features) const { DeviceInfo device_perf = GetDeviceInfo(op_features.device()); + if (device_perf.gigaops <= 0 || device_perf.gb_per_sec <= 0) { + VLOG(1) << "BAD DEVICE. Op:" << op_features.op() + << " device type:" << op_features.device().type() + << " device model:" << op_features.device().model(); + } + Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.gigaops)); VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9 << " Execution Time (ns):" << compute_cost.count(); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index d7d07ee7a55665a2d809588f45fbfd166bd2f76a..020492a3e9e23a8360a5e8804bc51ba6c5de67d1 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -323,8 +323,13 @@ Status VirtualScheduler::Init() { } // Get the nodes that would run to output fetch_nodes. + bool ill_formed = false; std::vector nodes = - ComputeTransitiveFanin(graph, fetch_nodes); + ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed); + if (ill_formed) { + return errors::InvalidArgument( + "Ill formed graph or invalid set of fetch nodes specified"); + } // TODO(dyoon): this is a bit inefficient as name_to_node is already built in // ComputeTransitiveFanin(). diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index f4e2de75a60182f3b2bbc366c076052bd0fae118..173ce9c09c2fd98d855a801131ed16a796d9caac 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -46,6 +46,7 @@ class GraphView { }; explicit GraphView(GraphDef* graph); + GraphDef* GetGraph() const { return graph_; } NodeDef* GetNode(const string& node_name) const; // Get the specified input port. Note that the special '-1' port_id can be // used to access the controlling nodes (i.e. the nodes connected to node_name diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 68de03e81ca83c0bc1028d167c10453791ed9afc..2ac31ebf6a24c3b13f396a5176e6402a72a0cb67 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -125,6 +125,7 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:grappler_test", ], ) @@ -289,6 +290,7 @@ cc_library( "//tensorflow/core/grappler/costs:graph_memory", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:topological_sort", + "//tensorflow/core/grappler/utils:traversal", ], ) @@ -300,11 +302,13 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/utils:grappler_test", ], ) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 0aeff6222c291455c04cf3fb68a90298724385dd..37a4759fddd5d990e46691cca55b6238e3acf7cc 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1658,7 +1658,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, // more with the original node name. for (const auto& fetch : item.fetch) { const NodeDef* fetch_node = node_map_->GetNode(fetch); - if (fetch_node && NumOutputs(*fetch_node) == 1) { + if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) { nodes_whitelist_.insert(fetch_node->name()); } } diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 849a88770ae6127c6f2e3fac968a976c0a523a0b..46998dcc91c8df2313ff92b056f732379b173661 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -20,30 +20,15 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/public/session.h" namespace tensorflow { namespace grappler { namespace { -class ConstantFoldingTest : public ::testing::Test { - protected: - std::vector EvaluateNodes(const GraphDef& graph, - const std::vector& fetch) { - SessionOptions options; - std::unique_ptr session(NewSession(options)); - TF_CHECK_OK(session->Create(graph)); - RunOptions run_options; - std::vector output_tensors; - TF_CHECK_OK( - session->Run(run_options, {}, fetch, fetch, &output_tensors, nullptr)); - TF_CHECK_OK(session->Close()); - return output_tensors; - } -}; +class ConstantFoldingTest : public GrapplerTest {}; TEST_F(ConstantFoldingTest, SimpleFolding) { // Build a simple graph with a few trivially prunable ops. diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index d2da125236ab4f9b386ba2c6dc808e2b030c819c..db64e530264f44d54492f678bc87a08bcf88bc26 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -36,20 +36,20 @@ namespace grappler { namespace { -int RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) { - int num_removed = 0; +bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) { + bool removed_input = false; int pos = 0; while (pos < node->input_size()) { if (node->input(pos) == input) { node->mutable_input()->SwapElements(pos, node->input_size() - 1); node->mutable_input()->RemoveLast(); node_map->RemoveOutput(NodeName(input), node->name()); + removed_input = true; } else { ++pos; } - ++num_removed; } - return num_removed; + return removed_input; } // Remove duplicate control inputs. @@ -71,6 +71,48 @@ void PruneControlInputs(NodeDef* node) { } // namespace +bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) { + if (!IsIdentity(node)) { + return true; + } + if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { + return false; + } + if (!fetch_nodes_known_) { + // The output values of this node may be needed. + return false; + } + const NodeDef* input = node_map_->GetNode(NodeName(node.input(0))); + CHECK(input != nullptr) << "node = " << node.name() + << " input = " << node.input(0); + // Don't remove Identity nodes corresponding to Variable reads or following + // Recv. + if (IsVariable(*input) || IsRecv(*input)) { + return false; + } else if (IsSwitch(*input)) { + // Don't turn Identity nodes following Switch into NoOp or remove them + // if it requires anchoring a control dependencies the Switch node, which + // is not valid. + if (StringPiece(node.name()).starts_with(kConstantFoldingCtrl)) { + // TODO(rmlarsen): Try to remove this artificial contraint. + return false; + } + } + for (auto consumer : node_map_->GetOutputs(node.name())) { + if (node.input_size() > 1 && IsMerge(*consumer)) { + return false; + } + if (IsSwitch(*input)) { + for (const string& consumer_input : consumer->input()) { + if (consumer_input == AsControlDependency(node.name())) { + return false; + } + } + } + } + return true; +} + bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) { if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { return false; @@ -100,18 +142,8 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) { return false; } - // Don't turn Identity nodes inserted by Grappler after Switch into NoOp, - // since we cannot anchor control dependencies on Switch nodes. - // Don't remove Identity nodes corresponding to Variable reads. - if (IsIdentity(node)) { - const NodeDef* input = node_map_->GetNode(NodeName(node.input(0))); - if (input != nullptr) { - if (IsVariable(*input) || - (StringPiece(node.name()).starts_with(kConstantFoldingCtrl) && - IsSwitch(*input))) { - return false; - } - } + if (!SafeToRemoveIdentity(node)) { + return false; } const std::unordered_set do_not_rewrite_ops{ @@ -125,18 +157,20 @@ void DependencyOptimizer::OptimizeNode(int node_idx, SetVector* nodes_to_simplify, std::set* nodes_to_delete) { NodeDef* node = optimized_graph_->mutable_node(node_idx); - + const bool is_noop = IsNoOp(*node); + const bool is_identity = IsIdentity(*node); + const string node_name = node->name(); // Constant nodes with no input control dependency are always executed early, // so we can prune all their output control dependencies. if (IsConstant(*node) && node->input_size() == 0) { - const std::set output_nodes = node_map_->GetOutputs(node->name()); + const std::set output_nodes = node_map_->GetOutputs(node_name); for (NodeDef* fanout : output_nodes) { bool optimize_fanout = false; bool data_connection = false; for (int i = fanout->input_size() - 1; i >= 0; --i) { int pos; string input_name = ParseNodeName(fanout->input(i), &pos); - if (input_name == node->name()) { + if (input_name == node_name) { if (pos < 0) { fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1); fanout->mutable_input()->RemoveLast(); @@ -149,22 +183,21 @@ void DependencyOptimizer::OptimizeNode(int node_idx, if (optimize_fanout) { nodes_to_simplify->PushBack(node_to_idx_[fanout]); if (!data_connection) { - node_map_->RemoveOutput(node->name(), fanout->name()); + node_map_->RemoveOutput(node_name, fanout->name()); } } } - if (node_map_->GetOutputs(node->name()).empty() && fetch_nodes_known_ && - nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) { + if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ && + nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) { // Mark the node for deletion. nodes_to_delete->insert(node_to_idx_[node]); } - return; } // Change ops that only have control dependencies as outputs to NoOps. - if (node->op() != "NoOp" && SafeToConvertToNoOp(*node)) { - VLOG(1) << "***** Replacing " << node->name() << " (" << node->op() + if (!is_noop && SafeToConvertToNoOp(*node)) { + VLOG(1) << "***** Replacing " << node_name << " (" << node->op() << ") with NoOp."; // The outputs of this node are not consumed. Replace its inputs with // control dependencies and replace the op itself with the NoOp op. @@ -186,7 +219,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx, old_input, optimized_graph_, node_map_.get()); if (ctrl_inputs.insert(ctrl_input).second) { node->set_input(pos, ctrl_input); - node_map_->UpdateInput(node->name(), old_input, ctrl_input); + node_map_->UpdateInput(node_name, old_input, ctrl_input); const NodeDef* old_input_node = node_map_->GetNode(old_input); nodes_to_simplify->PushBack(node_to_idx_[old_input_node]); } @@ -194,6 +227,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx, } node->set_op("NoOp"); node->clear_attr(); + nodes_to_simplify->PushBack(node_to_idx_[node]); + return; } // Remove NoOp nodes if the product of their fan-in and fan-out is less than @@ -222,9 +257,30 @@ void DependencyOptimizer::OptimizeNode(int node_idx, // a and x, respectively, are on the same device. Control edges across device // boundaries require inter-device communication (Send/Recv pairs to be // inserted in the graph), which is very costly. + // + // We also remove identity nodes, subject to the same constraints on number of + // resulting control edges and device boundary crossings: + // + // Case a) + // +----------+ ---> a +---+ ---> a + // x --> | Identity | --^> b ==> | x | --^> b + // | | ... | | ... + // +----------+ --^> c +---+ --^> c + // + // Case b) + // x ---> +----------+ ---> a x ---> +---+ + // y --^> | Identity | ==> y --^> | a | + // ... | | ... | | + // z --^> +----------+ z --^> +---+ + // + // Case c) + // +----------+ x ---> +---+ + // x ---> | Identity | ---> a ==> \--^> | a | + // y --^> | | --^> b /\ +---+ + // +----------+ y --^> b - if (node->op() == "NoOp") { - const auto& output_node_set = node_map_->GetOutputs(node->name()); + if (is_noop || is_identity) { + const auto& output_node_set = node_map_->GetOutputs(node_name); const std::vector output_nodes(output_node_set.begin(), output_node_set.end()); const int num_outputs = output_nodes.size(); @@ -233,15 +289,14 @@ void DependencyOptimizer::OptimizeNode(int node_idx, if (num_inputs * num_outputs > num_inputs + num_outputs) { return; } - VLOG(1) << "***** Rerouting input around " << node->name(); std::vector input_nodes; for (int i = 0; i < num_inputs; ++i) { - NodeDef* tmp = node_map_->GetNode(node->input(i)); - CHECK_NE(tmp, nullptr); - input_nodes.push_back(tmp); + NodeDef* input_node = node_map_->GetNode(node->input(i)); + CHECK_NE(input_node, nullptr); + input_nodes.push_back(input_node); } - // Make sure that we don't increase the number of control edges that cross + // Make sure that we don't increase the number of edges that cross // device boundaries. if ((num_inputs == 1 && num_outputs > 1 && input_nodes[0]->device() != node->device()) || @@ -266,40 +321,75 @@ void DependencyOptimizer::OptimizeNode(int node_idx, if (num_cross_after > num_cross_before) { return; } + // To avoid potentially removing Identity nodes following _Recv nodes, + // we require that no device crossings occur in that case. + // TODO(rmlarsen): See if we can relax this condition. + if (is_identity && (num_cross_after > 0 || num_cross_before > 0)) { + return; + } + } + if (is_identity && !SafeToRemoveIdentity(*node)) { + return; } + + VLOG(1) << "***** Rerouting input around\n" << node->DebugString(); + // Now remove the node and re-wire its inputs to its outputs. for (auto consumer : output_nodes) { bool updated_consumer = false; - VLOG(1) << "***** Considering consumer " << consumer->name() << "\n" - << consumer->DebugString(); + VLOG(1) << "consumer before:\n" << consumer->DebugString(); for (int i = 0; i < num_inputs; ++i) { const NodeDef* input = input_nodes[i]; // Forward dependency from input to consumer if it doesn't already // depend on it. - if (node_map_->GetOutputs(input->name()).count(consumer) == 0) { - consumer->add_input(AsControlDependency(input->name())); + if (is_identity && i == 0) { + // Replace regular input from Identity node. + bool found_input = false; + string new_input; + const string& input_to_forward = node->input(0); + CHECK(!IsControlInput(input_to_forward)); + for (int j = 0; j < consumer->input_size(); ++j) { + const string& old_input = consumer->input(j); + if (old_input == node_name) { + new_input = input_to_forward; + node_map_->UpdateInput(consumer->name(), old_input, new_input); + consumer->set_input(j, new_input); + found_input = true; + } else if (old_input == AsControlDependency(NodeName(node_name))) { + new_input = AsControlDependency(NodeName(input_to_forward)); + node_map_->UpdateInput(consumer->name(), old_input, new_input); + consumer->set_input(j, new_input); + found_input = true; + } + } + CHECK(found_input); updated_consumer = true; - node_map_->AddOutput(input->name(), consumer->name()); - nodes_to_simplify->PushBack(node_to_idx_[input]); + } else { + // Forward dependency from input to consumer if it doesn't already + // depend on it. + if (node_map_->GetOutputs(input->name()).count(consumer) == 0) { + consumer->add_input(AsControlDependency(input->name())); + node_map_->AddOutput(input->name(), consumer->name()); + nodes_to_simplify->PushBack(node_to_idx_[input]); + updated_consumer = true; + } } } // Remove dependency on node from consumer. - updated_consumer |= RemoveInput( - consumer, AsControlDependency(node->name()), node_map_.get()); + updated_consumer |= RemoveInput(consumer, AsControlDependency(node_name), + node_map_.get()); if (updated_consumer) { - VLOG(1) << "***** Updated consumer " << consumer->name() << " (" - << consumer->op() << ")"; nodes_to_simplify->PushBack(node_to_idx_[consumer]); } + VLOG(1) << "consumer after:\n" << consumer->DebugString(); } - - node_map_->RemoveOutputs(node->name()); + node_map_->RemoveOutputs(node_name); if (fetch_nodes_known_ && - nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) { + nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) { // Mark the node for deletion. nodes_to_delete->insert(node_idx); - // Unconnect the node from its inputs to enable further optimizations. - node_map_->RemoveInputs(node->name()); + // Disconnect the node from its inputs to enable further optimizations. + node_map_->RemoveInputs(node_name); node->clear_input(); } } @@ -330,13 +420,18 @@ Status DependencyOptimizer::OptimizeDependencies() { std::set nodes_to_delete; for (int i = 0; i < optimized_graph_->node_size(); ++i) { const NodeDef& node = optimized_graph_->node(i); - if (node.op() == "NoOp" || IsConstant(node) || SafeToConvertToNoOp(node)) { + if (IsNoOp(node) || IsIdentity(node) || IsConstant(node) || + SafeToConvertToNoOp(node)) { nodes_to_simplify.PushBack(i); } } while (!nodes_to_simplify.Empty()) { - OptimizeNode(nodes_to_simplify.PopBack(), &nodes_to_simplify, - &nodes_to_delete); + int node_to_simplify = nodes_to_simplify.PopBack(); + // Discard nodes that were marked for deletion already. + while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) { + node_to_simplify = nodes_to_simplify.PopBack(); + } + OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete); } if (fetch_nodes_known_) { @@ -431,9 +526,10 @@ Status DependencyOptimizer::TransitiveReduction() { if (longest_distance[target] > 1) { const int input_slot = control_output.second; control_edges_to_remove[target].emplace(input_slot, source); - VLOG(1) << "Removing edge from:\n" - << optimized_graph_->node(source).DebugString() << "\n\nto:\n\n" - << optimized_graph_->node(target).DebugString(); + // VLOG(1) << "Removing edge from:\n" + // << optimized_graph_->node(source).DebugString() << + // "\n\nto:\n\n" + // << optimized_graph_->node(target).DebugString(); } } } @@ -473,8 +569,8 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph_ = item.graph; nodes_to_preserve_ = item.NodesToPreserve(); fetch_nodes_known_ = !item.fetch.empty(); - CleanControlInputs(); + const int num_iterations = 2; for (int iteration = 0; iteration < num_iterations; ++iteration) { Status topo_sort_status; @@ -491,9 +587,12 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } else { LOG(ERROR) << topo_sort_status.error_message(); } - - // Turn nodes with only control outputs into NoOps, prune NoOps. + // Turn nodes with only control outputs into NoOps, prune NoOp and Identity + // nodes. TF_RETURN_IF_ERROR(OptimizeDependencies()); + + // Dedup control inputs. + CleanControlInputs(); } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h index cfc53244397adb7a3c267aaa6ab865385054736a..0f47528a0435d3e90d92b07306d7b1a4a072ce27 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -29,8 +29,9 @@ namespace grappler { // optimizations, such as removing nodes that are effectively noops. class DependencyOptimizer : public GraphOptimizer { public: - DependencyOptimizer() {} - explicit DependencyOptimizer(RewriterConfig::Toggle /*unused*/) {} + DependencyOptimizer() : opt_level_(RewriterConfig::ON) {} + explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} ~DependencyOptimizer() override {} string name() const override { return "dependency_optimizer"; }; @@ -42,6 +43,9 @@ class DependencyOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: + // Returns true if node is not an Identity node or if it is an Identity + // that is safe to remove. + bool SafeToRemoveIdentity(const NodeDef& node); // Returns true if it is safe to convert node to NoOp. bool SafeToConvertToNoOp(const NodeDef& node); // Removes all duplicate control dependencies. @@ -61,6 +65,7 @@ class DependencyOptimizer : public GraphOptimizer { // Main driver of dependency optimizations. Status OptimizeDependencies(); + RewriterConfig::Toggle opt_level_; bool fetch_nodes_known_; std::unordered_set nodes_to_preserve_; std::unique_ptr node_map_; diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index f5027a4a99e4f28b4b49df914e9247a008036c20..33d6b992d21212fe325c642b87d3c3736185c445 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -167,12 +167,14 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) { ops::Const(scope.WithOpName("c2").WithControlDependencies(ctrl_dep_id), {1.0f, 2.0f}, {1, 2}); Output neg1 = ops::Neg(scope.WithOpName("neg1"), s.output_false); + Output neg2 = ops::Neg(scope.WithOpName("neg2"), ctrl_dep_id); GrapplerItem item; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); item.fetch.push_back("c1"); item.fetch.push_back("c2"); item.fetch.push_back("neg1"); + item.fetch.push_back("neg2"); DependencyOptimizer optimizer; GraphDef output; @@ -323,25 +325,148 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) { } } +TEST_F(DependencyOptimizerTest, RemoveIdentity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT); + Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT); + Output z = ops::RandomUniform(s.WithOpName("z"), {1, 2}, DT_FLOAT); + + // Identity nodes to be removed. + // Case a) with a single input- and multiple outputs. + auto id_a = ops::Identity(s.WithOpName("id_a"), x); + // Case b) with multiple inputs and a single output. + auto id_b = ops::Identity( + s.WithOpName("id_b").WithControlDependencies(y).WithControlDependencies( + z), + x); + // Case c) with two inputs and two outputs. + auto id_c = ops::Identity(s.WithOpName("id_c").WithControlDependencies(y), x); + + // Output for Case a. + Output a_a = ops::Identity(s.WithOpName("a_a"), id_a); + Output a_b = ops::Identity(s.WithOpName("a_b"), id_a); + Output a_c = + ops::Identity(s.WithOpName("a_c").WithControlDependencies(id_a), z); + Output a_d = + ops::Identity(s.WithOpName("a_d").WithControlDependencies(id_a), z); + // Output for Case b. + Output b_a = ops::Identity(s.WithOpName("b_a"), id_b); + // Output for Case c. + Output c_a = ops::Identity(s.WithOpName("c_a"), id_c); + Output c_b = + ops::Identity(s.WithOpName("c_b").WithControlDependencies(id_c), z); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"}; + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size() - 3, output.node_size()); + for (const NodeDef& node : output.node()) { + EXPECT_NE("id_a", node.name()); + EXPECT_NE("id_b", node.name()); + EXPECT_NE("id_c", node.name()); + if (node.name() == "a_a" || node.name() == "a_b") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + } + if (node.name() == "a_c" || node.name() == "a_d") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("z", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + } + if (node.name() == "b_a") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^y", node.input(1)); + EXPECT_EQ("^z", node.input(2)); + } + if (node.name() == "c_a") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^y", node.input(1)); + } + if (node.name() == "c_b") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("z", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + EXPECT_EQ("^y", node.input(2)); + } + } +} + +TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) { + // Corner cases with repeated inputs. + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + ops::Variable x(scope.WithOpName("x"), {}, DT_BOOL); + ops::Variable y(scope.WithOpName("y"), {}, DT_BOOL); + ops::Switch sw(scope.WithOpName("switch"), x, x); + // id0 should be removed. + Output id0 = ops::Identity(scope.WithOpName("id0"), sw.output_true); + // id1 should not be removed, since it would anchor a control dependency + // on the switch. + Output id1 = ops::Identity(scope.WithOpName("id1"), sw.output_false); + Output or0 = ops::LogicalOr(scope.WithOpName("or0"), id0, id0); + Output or1 = ops::LogicalOr(scope.WithOpName("or1"), id0, y); + Output or2 = ops::LogicalOr( + scope.WithOpName("or2").WithControlDependencies(id1), y, y); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch.push_back("or0"); + item.fetch.push_back("or1"); + item.fetch.push_back("or2"); + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size() - 1, output.node_size()); + for (const NodeDef& node : output.node()) { + EXPECT_NE("id0", node.name()); + if (node.name() == "or0") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("switch:1", node.input(0)); + EXPECT_EQ("switch:1", node.input(1)); + } + if (node.name() == "or1") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("switch:1", node.input(0)); + EXPECT_EQ("y", node.input(1)); + } + if (node.name() == "or2") { + // or1 should be unchanged. + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("y", node.input(1)); + EXPECT_EQ("^id1", node.input(2)); + } + } +} + TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); Output x = ops::Square(s.WithOpName("x"), c); - Output id1 = ops::Identity(s.WithOpName("id1"), x); - Output id2 = - ops::Identity(s.WithOpName("id2").WithControlDependencies({x}), id1); + Output neg1 = ops::Neg(s.WithOpName("neg1"), x); + Output neg2 = + ops::Neg(s.WithOpName("neg2").WithControlDependencies({x}), neg1); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch.push_back("id2"); + item.fetch.push_back("neg2"); DependencyOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(4, output.node_size()); - EXPECT_EQ("id2", output.node(3).name()); + EXPECT_EQ("neg2", output.node(3).name()); EXPECT_EQ(1, output.node(3).input_size()); - EXPECT_EQ("id1", output.node(3).input(0)); + EXPECT_EQ("neg1", output.node(3).input(0)); } TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { @@ -356,17 +481,18 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { Output grappler_added_id = ops::Identity( scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true); Output c1 = ops::Const(scope.WithOpName("c1") - .WithControlDependencies(id0) .WithControlDependencies(id_after_var) .WithControlDependencies(grappler_added_id), {1.0f, 2.0f}, {1, 2}); Output id1 = ops::Identity(scope.WithOpName("id1"), c1); + Output id2 = ops::Identity(scope.WithOpName("id2"), id0); Output fetch = ops::Identity(scope.WithOpName("fetch").WithControlDependencies(id1), c1); GrapplerItem item; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); item.fetch.push_back("c1"); + item.fetch.push_back("id2"); item.fetch.push_back("fetch"); DependencyOptimizer optimizer; @@ -377,8 +503,8 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { EXPECT_EQ(item.graph.node_size() - 2, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); - // "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1" nor - // "id_after_var" should be eliminated. + // "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1", + // "id_after_var, nor "id2"" should be eliminated. EXPECT_NE("id0", node.name()); EXPECT_NE("id1", node.name()); if (node.name() == "c1") { diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 6f95a00fa31a894f5647730d074b4a2fd0e918a5..be18e7d6f01ca4c5cbabb9099685fa7b8efeb1e3 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/static_schedule.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/grappler/utils/traversal.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -497,7 +498,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { if (!IsAddN(node)) { continue; } - // There is nothing to gain by optimizing nodes with 2 inputs of fewer. + // There is nothing to gain by optimizing nodes with 2 or fewer inputs. if (view.NumFanins(node, false) <= 2) { continue; } @@ -559,6 +560,54 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { VLOG(1) << "Missing properties for " << node->name(); continue; } + + // Compute a topological ordering for the node fanin. + std::unordered_map topo_order; + ReverseDfs(view, {node}, nullptr, + [&topo_order](NodeDef* n) { + int topo_index = topo_order.size(); + topo_order[n] = topo_index; + }, + nullptr); + + std::vector input_topo_index; + + for (int i = 0; i < node->input_size(); ++i) { + const string& input = node->input(i); + const string node_name = NodeName(input); + NodeDef* node = view.GetNode(node_name); + input_topo_index.push_back(topo_order.at(node)); + } + int min_input_topo_index = INT_MAX; + int min_input_id = -1; + for (int i = 0; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + // control inputs are always last. + break; + } + const int current = input_topo_index[i]; + if (current < min_input_topo_index) { + min_input_topo_index = current; + min_input_id = i; + } + } + CHECK_LE(0, min_input_id); + std::vector pre_ctrl_deps; + std::vector post_ctrl_deps; + for (int i = node->input_size() - 1; i >= 0; --i) { + if (!IsControlInput(node->input(i))) { + // control inputs are always last. + break; + } + if (input_topo_index[i] < min_input_topo_index) { + // These control dependencies can be executed before the node. + pre_ctrl_deps.push_back(node->input(i)); + } else { + // These control dependencies should be executed after the node. + post_ctrl_deps.push_back(node->input(i)); + } + } + const TensorShapeProto& shape = properties.GetOutputProperties(node->name())[0].shape(); DataType dtype = node->attr().at("T").type(); @@ -573,13 +622,19 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape; (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name()); + for (const string& ctrl_dep : pre_ctrl_deps) { + *tmp_var->add_input() = ctrl_dep; + } + *tmp_var->add_input() = + AsControlDependency(NodeName(node->input(min_input_id))); + // Initialize it to zero NodeDef* zeros = item->graph.add_node(); zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros")); zeros->set_op("ZerosLike"); zeros->set_device(device); (*zeros->mutable_attr())["T"].set_type(dtype); - *zeros->add_input() = node->input(0); + *zeros->add_input() = node->input(min_input_id); NodeDef* initialize = item->graph.add_node(); initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer")); @@ -593,15 +648,14 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { std::vector accumulates; for (int i = 0; i < node->input_size(); ++i) { const string& input = node->input(i); - if (IsControlInput(input)) { - *zeros->add_input() = input; - } else { + if (!IsControlInput(input)) { NodeDef* accumulate = item->graph.add_node(); accumulate->set_name( strings::StrCat(node->name(), "/tmp_var_accum_", i)); accumulate->set_op("AssignAdd"); accumulate->set_device(device); (*accumulate->mutable_attr())["T"].set_type(dtype); + (*accumulate->mutable_attr())["use_locking"].set_b(true); *accumulate->add_input() = initialize->name(); *accumulate->add_input() = input; accumulates.push_back(accumulate); @@ -618,6 +672,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { for (const NodeDef* accum : accumulates) { *node->add_input() = AsControlDependency(accum->name()); } + for (const string& ctrl_dep : post_ctrl_deps) { + *node->add_input() = ctrl_dep; + } + updated_graph = true; } diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index f5d9c87992655c7fbd94919ca5a31f64207cd79c..5d7913e0c018ecf14cc09ab91d3a71125c720aa5 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -19,17 +19,18 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" namespace tensorflow { namespace grappler { namespace { -class RecomputeSubgraphTest : public ::testing::Test {}; +class RecomputeSubgraphTest : public GrapplerTest {}; TEST_F(RecomputeSubgraphTest, SimpleSubgraph) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -193,7 +194,7 @@ TEST_F(RecomputeSubgraphTest, MultiNode) { EXPECT_EQ("^gradients/BN1Grad", recompute_trigger_c->input(0)); } -class MemoryOptimizerTest : public ::testing::Test { +class MemoryOptimizerTest : public GrapplerTest { public: static std::unique_ptr CreateVirtualCluster() { DeviceProperties cpu_device; @@ -201,6 +202,7 @@ class MemoryOptimizerTest : public ::testing::Test { cpu_device.set_frequency(1000); cpu_device.set_num_cores(4); cpu_device.set_bandwidth(32); + cpu_device.set_memory_size(1024 * 1024); DeviceProperties gpu_device; gpu_device.set_type("GPU"); gpu_device.set_frequency(1000); @@ -346,17 +348,18 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { TEST_F(MemoryOptimizerTest, AccumulationRewrites) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"), - {128, 128, 8}, DT_FLOAT); - Output b = ops::Variable(s.WithOpName("b").WithDevice("/gpu:0"), - {128, 128, 8}, DT_FLOAT); - Output c = ops::Variable(s.WithOpName("c").WithDevice("/gpu:0"), - {128, 128, 8}, DT_FLOAT); - Output d = ops::AddN(s.WithOpName("d").WithDevice("/gpu:0"), {a, b, c}); + Output a = ops::RandomNormal(s.WithOpName("a").WithDevice("/cpu:0"), + {128, 128, 8}, DT_FLOAT); + Output b = ops::RandomNormal(s.WithOpName("b").WithDevice("/cpu:0"), + {128, 128, 8}, DT_FLOAT); + Output c = ops::RandomNormal(s.WithOpName("c").WithDevice("/cpu:0"), + {128, 128, 8}, DT_FLOAT); + Output d = ops::AddN(s.WithOpName("d").WithDevice("/cpu:0"), {a, b, c}); + Output e = ops::Square(s.WithOpName("e").WithDevice("/cpu:0"), d); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"d"}; + item.fetch = {"e"}; std::unique_ptr cluster(CreateVirtualCluster()); MemoryOptimizer optimizer(RewriterConfig::SCHEDULING_HEURISTICS); @@ -375,9 +378,27 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) { } else if (node.name() == "d/tmp_var") { EXPECT_EQ("TemporaryVariable", node.op()); count++; + } else if (node.name() == "e") { + EXPECT_EQ("Square", node.op()); + EXPECT_EQ("d", node.input(0)); + count++; + } + } + EXPECT_EQ(4, count); + + std::vector fetch = {"a", "b", "c", "e"}; + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(4, tensors.size()); + + for (int i = 0; i < tensors[0].NumElements(); ++i) { + float actual = tensors[3].flat()(i); + float expected = 0.0f; + for (int j = 0; j < 3; ++j) { + expected += tensors[j].flat()(i); } + expected *= expected; + EXPECT_NEAR(actual, expected, 1e-4); } - EXPECT_EQ(3, count); } } // namespace diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 8099214c2bd81e642bbcc8fc913d1ec3307d6251..634577ed305cf191d62fcb7e2edee2d2e3777d9e 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" @@ -207,7 +208,7 @@ string AsControlDependency(const string& node_name) { : strings::StrCat("^", node_name); } -int NumOutputs(const NodeDef& node) { +int NumOutputs(const NodeDef& node, GraphDef* graph) { int num_outputs = 0; const OpDef* op_def = nullptr; auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); @@ -222,6 +223,12 @@ int NumOutputs(const NodeDef& node) { num_outputs++; } } + } else { + FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library()); + auto status = fdef.LookUpOpDef(node.op(), &op_def); + if (status.ok()) { + num_outputs = op_def->output_arg_size(); + } } return num_outputs; } diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index c04a9a666dd68c42f378543bd2fc997a4bde872c..8840c44d050900b51e7707c1ff55e95b86bd0832 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -135,7 +135,7 @@ string AsControlDependency(const string& node); // Returns the number of outputs of a node according to its OpDef. Note that // some of the outputs may be unconnected. -int NumOutputs(const NodeDef& node); +int NumOutputs(const NodeDef& node, GraphDef* graph); // Number of connected non-control inputs. int NumNonControlInputs(const NodeDef& node); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 534f7a063fe90bf72f8a2afba7ae8f75b8472a36..0a9dbe22cfe3cd01c2c61661adcdd4839a957f03 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -99,3 +99,49 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "traversal", + srcs = ["traversal.cc"], + hdrs = ["traversal.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + ], +) + +tf_cc_test( + name = "traversal_test", + srcs = ["traversal_test.cc"], + deps = [ + ":traversal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "grappler_test", + testonly = 1, + srcs = [ + "grappler_test.cc", + ], + hdrs = ["grappler_test.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core/grappler:utils", + ], +) diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..813f65f825759ca22dba2bdfd8433d946b7dd852 --- /dev/null +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { + +std::vector GrapplerTest::EvaluateNodes( + const GraphDef& graph, const std::vector& node_names) { + SessionOptions options; + std::unique_ptr session(NewSession(options)); + TF_CHECK_OK(session->Create(graph)); + RunOptions run_options; + std::vector output_tensors; + TF_CHECK_OK(session->Run(run_options, {}, node_names, node_names, + &output_tensors, nullptr)); + TF_CHECK_OK(session->Close()); + return output_tensors; +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h new file mode 100644 index 0000000000000000000000000000000000000000..46ce47c8c3b6bc18b6eac76bbdb8ec1f8a58fab2 --- /dev/null +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_GRAPPLER_TEST_H_ +#define TENSORFLOW_GRAPPLER_GRAPPLER_TEST_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class GrapplerTest : public ::testing::Test { + protected: + std::vector EvaluateNodes(const GraphDef& graph, + const std::vector& node_names); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_GRAPPLER_TEST_H_ diff --git a/tensorflow/core/grappler/utils/traversal.cc b/tensorflow/core/grappler/utils/traversal.cc new file mode 100644 index 0000000000000000000000000000000000000000..f44f53c4e63805544fa480628e805303064edb3d --- /dev/null +++ b/tensorflow/core/grappler/utils/traversal.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/traversal.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { +namespace grappler { + +void ReverseDfs(const GraphView& graph_view, const std::vector& from, + const std::function& pre_order, + const std::function& post_order, + const std::function& on_back_edge) { + // Stack of work to do. + struct StackElem { + NodeDef* node; + bool children_visited; + NodeDef* src; + }; + std::vector stack; + + stack.reserve(from.size()); + for (NodeDef* node : from) { + stack.push_back(StackElem{node, false}); + } + + enum NodeState { NOT_VISITED = 0, VISITING = 1, DONE = 2 }; + std::unordered_map node_state; + while (!stack.empty()) { + StackElem w = stack.back(); + stack.pop_back(); + + if (w.children_visited) { + // We've processed all the children of this node + node_state[w.node] = DONE; + if (post_order) { + post_order(w.node); + } + continue; + } + + auto& rslt = node_state[w.node]; + if (rslt == DONE) { + continue; + } else if (rslt == VISITING) { + // Loop detected + if (on_back_edge) { + on_back_edge(w.src, w.node); + } + continue; + } + rslt = VISITING; + if (pre_order) { + pre_order(w.node); + } + + // Enqueue the node again with the children_visited flag set to true. + stack.push_back(StackElem{w.node, true, w.src}); + + // Now enqueu the node children. + for (const auto fanin : graph_view.GetFanins(*w.node, true)) { + stack.push_back(StackElem{fanin.node, false, w.node}); + } + } +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/traversal.h b/tensorflow/core/grappler/utils/traversal.h new file mode 100644 index 0000000000000000000000000000000000000000..bb3fa090e8fdaf12ed6dcb18eb1511c55496a125 --- /dev/null +++ b/tensorflow/core/grappler/utils/traversal.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ + +#include +#include "tensorflow/core/grappler/graph_view.h" + +namespace tensorflow { +namespace grappler { + +// Traverse the graph in reverse dfs order, starting from the list of nodes +// specified in the 'from' argument. The pre_order and post_order functors will +// be called on each reachable node (including the 'from' nodes) in pre and post +// order. If loops are found, the on_back_edge functor will be called on the +// corresponding back edges. Moreover, the pre and post order will assume that +// these back edges will be cut. +void ReverseDfs(const GraphView& graph_view, const std::vector& from, + const std::function& pre_order, + const std::function& post_order, + const std::function& on_back_edge); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ diff --git a/tensorflow/core/grappler/utils/traversal_test.cc b/tensorflow/core/grappler/utils/traversal_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cc68bd1a9637cb6f61955e8fa5d495a34f19cb09 --- /dev/null +++ b/tensorflow/core/grappler/utils/traversal_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/traversal.h" +//#include "tensorflow/core/framework/node_def.pb.h" +//#include "tensorflow/core/lib/core/status_test_util.h" +//#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class TraversalTest : public ::testing::Test { + protected: + static NodeDef CreateNode(const string& name, + const std::vector& inputs) { + return CreateNode(name, "", inputs); + } + static NodeDef CreateNode(const string& name, const string& op, + const std::vector& inputs) { + NodeDef node; + node.set_name(name); + if (!op.empty()) { + node.set_op(op); + } + for (const string& input : inputs) { + node.add_input(input); + } + return node; + } +}; + +TEST_F(TraversalTest, ReverseDfsNoLoop) { + GraphDef graph; + *graph.add_node() = CreateNode("2", {"5"}); + *graph.add_node() = CreateNode("0", {"5", "4"}); + *graph.add_node() = CreateNode("1", {"4", "3"}); + *graph.add_node() = CreateNode("3", {"2"}); + *graph.add_node() = CreateNode("5", {}); + *graph.add_node() = CreateNode("4", {}); + + std::vector start_nodes = {graph.mutable_node(1), + graph.mutable_node(2)}; + std::vector pre_order; + std::vector post_order; + bool found_back_edge = false; + ReverseDfs( + GraphView(&graph), start_nodes, + [&pre_order](NodeDef* n) { pre_order.push_back(n->name()); }, + [&post_order](NodeDef* n) { post_order.push_back(n->name()); }, + [&found_back_edge](NodeDef*, NodeDef*) { found_back_edge = true; }); + + EXPECT_EQ(std::vector({"1", "4", "3", "2", "5", "0"}), pre_order); + EXPECT_EQ(std::vector({"4", "5", "2", "3", "1", "0"}), post_order); + EXPECT_FALSE(found_back_edge); +} + +TEST_F(TraversalTest, ReverseDfsWithLoop) { + GraphDef graph; + // Create a loop + *graph.add_node() = CreateNode("2", "Merge", {"1", "5"}); + *graph.add_node() = CreateNode("3", "Switch", {"2"}); + *graph.add_node() = CreateNode("4", "Identity", {"3"}); + *graph.add_node() = CreateNode("5", "NextIteration", {"4"}); + *graph.add_node() = CreateNode("1", "Enter", {}); + *graph.add_node() = CreateNode("6", "Exit", {"3"}); + + std::vector start_nodes = {graph.mutable_node(5)}; + std::vector pre_order; + std::vector post_order; + std::vector back_edges; + ReverseDfs( + GraphView(&graph), start_nodes, + [&pre_order](NodeDef* n) { pre_order.push_back(n->name()); }, + [&post_order](NodeDef* n) { post_order.push_back(n->name()); }, + [&back_edges](NodeDef* src, NodeDef* dst) { + back_edges.push_back(strings::StrCat(src->name(), "->", dst->name())); + }); + + EXPECT_EQ(std::vector({"6", "3", "2", "1", "5", "4"}), pre_order); + EXPECT_EQ(std::vector({"1", "4", "5", "2", "3", "6"}), post_order); + EXPECT_EQ(std::vector({"4->3"}), back_edges); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 77371c399e5fc7321f7c2b271aae32ce9655244b..ba4e6b1bae6073b42e4abc4071fdbc7f569f5a93 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -177,9 +177,10 @@ TEST_F(UtilsTest, ExecuteWithTimeout) { } TEST_F(UtilsTest, NumOutputs) { - EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode())); - EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode())); - EXPECT_EQ(1, NumOutputs(CreateDequeueNode())); + GraphDef graph; + EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph)); + EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph)); + EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph)); } TEST_F(UtilsTest, AsControlDependency) { diff --git a/tensorflow/core/kernels/batch_util.cc b/tensorflow/core/kernels/batch_util.cc index 7f2df95e2d55ac93f8a934010244dcbd1dcd28c8..1a45212ad29a7b8a578ce176db20eaf3d2193afd 100644 --- a/tensorflow/core/kernels/batch_util.cc +++ b/tensorflow/core/kernels/batch_util.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" +#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) + namespace tensorflow { namespace batch_util { @@ -61,6 +63,21 @@ Status HandleElementToSlice(Tensor element, Tensor* parent, int64 index, return Status::OK(); } +template <> +Status HandleElementToSlice(Tensor element, Tensor* parent, + int64 index, bool can_move) { + auto parent_as_matrix = parent->flat_outer_dims(); + auto element_flat = element.flat(); + if (can_move) { + for (int64 i = 0; i < element.NumElements(); ++i) { + parent_as_matrix(index, i) = std::move(element_flat(i)); + } + } else { + parent_as_matrix.chip(index, 0) = element_flat; + } + return Status::OK(); +} + // TODO(jsimsa): Add HandleElementToSlice specialization that moves // the data when possible. @@ -87,7 +104,6 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) { switch (element.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_variant(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", @@ -107,7 +123,6 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { switch (parent.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_variant(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopySliceToElement Unhandled data type: ", @@ -115,5 +130,101 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { } } +// The following five functions are copied from padding_fifo_queue.cc. +// TODO(mrry): Reconcile these functions with the similar methods in the +// queue implementation. +Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { + DCHECK_NE(parent->dim_size(0), 0); + if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { + TensorShape chip_shape = parent->shape(); + chip_shape.RemoveDim(0); + return errors::Internal( + "HandleElementToLargerSlice Cannot copy slice: number of entries in " + "element is greater than number of elements in parent slice. ", + "Shapes are: [element]: ", element.shape().DebugString(), + ", [parent slice]: ", chip_shape.DebugString()); + } + return Status::OK(); +} + +template +Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { + TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent)); + if (element.NumElements() == 0) { + return Status::OK(); + } + auto element_t = element.tensor(); + auto parent_t = parent->tensor(); + Eigen::DSizes slice_indices; + slice_indices[0] = index; + Eigen::DSizes slice_size; + slice_size[0] = 1; + for (size_t i = 1; i < slice_size.size(); ++i) { + slice_size[i] = element_t.dimension(i - 1); + } + parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); + return Status::OK(); +} + +template +Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, + int index) { +#define HANDLE_TYPE(T) \ + case DataTypeToEnum::value: { \ + return HandleElementToLargerSlice(element, parent, index); \ + } + + switch (element.dtype()) { + TF_CALL_DATASET_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + default: + return errors::Unimplemented( + "HandleElementToLargerSliceWithRank Unhandled data type: ", + element.dtype()); + } +} + +Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { + if (parent->dims() != element.dims() + 1) { + return errors::Internal( + "Mismatched ranks. Element's rank is: ", element.dims(), + " but element is meant to be a slice in output Tensor having rank: ", + parent->dims(), " (should be: ", element.dims() + 1, ")"); + } + +#define HANDLE_DIMS(NDIMS) \ + case NDIMS: { \ + TF_RETURN_IF_ERROR( \ + HandleElementToLargerSliceWithRank(element, parent, index)); \ + return Status::OK(); \ + } + + switch (element.dims()) { + HANDLE_DIMS(0); + HANDLE_DIMS(1); + HANDLE_DIMS(2); + HANDLE_DIMS(3); + HANDLE_DIMS(4); +#undef HANDLE_DIMS + default: + return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", + element.dims()); + } +} + +Status SetElementZero(Tensor* element, const Tensor& padding) { +#define HANDLE_TYPE(T) \ + if (element->dtype() == DataTypeToEnum::value) { \ + element->flat().setConstant(padding.scalar()()); \ + return Status::OK(); \ + } + TF_CALL_DATASET_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + return errors::Unimplemented("SetElementZero Unhandled data type: ", + element->dtype()); +} + } // namespace batch_util } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_util.h b/tensorflow/core/kernels/batch_util.h index 0d634ae7b07ee641eb13167d6f9fcb9ed5f0d974..a47bf1935db611417cea1d98ed8aff496efbf689 100644 --- a/tensorflow/core/kernels/batch_util.h +++ b/tensorflow/core/kernels/batch_util.h @@ -32,6 +32,16 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); // Copies the index^th slice of parent (in the 0th dimension) into element. Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index); +// Zero-initializes the tensor `element` using the scalar stored in `padding`. +// Both `element` and `padding` must have matching `dtype`. +Status SetElementZero(Tensor* element, const Tensor& padding); + +// Copies `element` into a (0th dimension) slice of `parent`, assuming +// the shape of `element` is strictly not larger along any axis than a +// slice. +Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index); + } // namespace batch_util } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/periodic_function.h b/tensorflow/core/kernels/batching_util/periodic_function.h index dbf1733dcc399522a673e5724dfeb62446f72a0f..36a4019002aa55c26fb5419c7a4d17562a367de8 100644 --- a/tensorflow/core/kernels/batching_util/periodic_function.h +++ b/tensorflow/core/kernels/batching_util/periodic_function.h @@ -114,7 +114,7 @@ class PeriodicFunction { void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_); const std::function function_; // Actual client function - const int64 interval_micros_; // Interval between calls. + const int64 interval_micros_; // Interval between calls. const Options options_; // Protects state below. diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index d73dcf0fa0e1b2b387b3ed53acd63d5c65683fd4..d5ea2b648f35efd03c04d00abc838edadd37570e 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -55,15 +55,14 @@ Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { // use the clock to be destroyed. std::unique_ptr CreateFakeClockAdvancerThread( test_util::FakeClockEnv* env, Notification* start, Notification* stop) { - return std::unique_ptr( - Env::Default()->StartThread({}, "FakeClockAdvancerThread", - [env, start, stop] { - start->WaitForNotification(); - while (!stop->HasBeenNotified()) { - env->AdvanceByMicroseconds(10); - Env::Default()->SleepForMicroseconds(10); - } - })); + return std::unique_ptr(Env::Default()->StartThread( + {}, "FakeClockAdvancerThread", [env, start, stop] { + start->WaitForNotification(); + while (!stop->HasBeenNotified()) { + env->AdvanceByMicroseconds(10); + Env::Default()->SleepForMicroseconds(10); + } + })); } TEST(SharedBatchSchedulerTest, Basic) { @@ -258,7 +257,7 @@ TEST(SharedBatchSchedulerTest, ObeysTimeout) { TEST(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) { Notification first_batch_processed, second_batch_processed; auto callback = [&first_batch_processed, &second_batch_processed]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { ASSERT_TRUE(batch->IsClosed()); if (batch->size() == 1) { first_batch_processed.Notify(); @@ -301,7 +300,7 @@ TEST(SharedBatchSchedulerTest, { Notification first_batch_processed, second_batch_processed; auto callback = [&first_batch_processed, &second_batch_processed]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { ASSERT_TRUE(batch->IsClosed()); if (batch->size() == 1) { first_batch_processed.Notify(); @@ -349,7 +348,7 @@ TEST(SharedBatchSchedulerTest, Fairness) { auto queue_0_callback = [&queue_0_first_batch_scheduled, &queue_0_first_batch_proceed, &queue_0_second_batch_scheduled]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { if (!queue_0_first_batch_scheduled.HasBeenNotified()) { queue_0_first_batch_scheduled.Notify(); queue_0_first_batch_proceed.WaitForNotification(); @@ -467,7 +466,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { TEST(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) { Notification queue_0_processing, queue_0_proceed; auto queue_0_callback = [&queue_0_processing, &queue_0_proceed]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { if (!queue_0_processing.HasBeenNotified()) { queue_0_processing.Notify(); queue_0_proceed.WaitForNotification(); diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index fc5a3e62885c92ec16a906df5a6e2d6245ccbbd6..547a7b40b9245d4b10c12830a0189b09c9dacc76 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -73,7 +73,6 @@ REGISTER(qint8) REGISTER(quint16) REGISTER(qint16) REGISTER(qint32) -TF_CALL_variant(REGISTER) #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \ !defined(__ANDROID_TYPES_FULL__) diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 920cd87858ab62357d6d65e0d4db4c26d157a75c..4ab6fdbca1a3415937213d46fac3058097130f55 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/kernels/constant_op.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -41,8 +42,33 @@ limitations under the License. namespace tensorflow { +namespace { + +std::unique_ptr StripTensorDataFromNodeDef( + OpKernelConstruction* ctx) { +#ifndef __ANDROID__ + DCHECK_EQ(NodeDef::descriptor()->field_count(), 5) + << "The NodeDef format has changed, and the attr-stripping code may need " + << "to be updated."; +#endif + const NodeDef& original = ctx->def(); + NodeDef* ret = new NodeDef; + ret->set_name(original.name()); + ret->set_op(original.op()); + ret->set_device(original.device()); + // Strip the "value" attr from the returned NodeDef. + // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses + // attrs that affect the cardinality of list-typed inputs and outputs, so it + // is safe to drop other attrs from the NodeDef. + AddNodeAttr("dtype", ctx->output_type(0), ret); + return std::unique_ptr(ret); +} + +} // namespace + ConstantOp::ConstantOp(OpKernelConstruction* ctx) - : OpKernel(ctx), tensor_(ctx->output_type(0)) { + : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)), + tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index e58f5f61f35fc9abfbe1da67c9e51a84c7ddbb6f..a376534badc73065e3ec01972dde85da7bbdb0f8 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -648,8 +648,9 @@ struct BatchNarrowMatrixTransposeDispatcher { static_assert( (TileLongSide & (TileLongSide - 1)) == 0, "The length of the longer side of the tile is always a power of 2."); - bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide && - min(tile_size_i, tile_size_j) <= TileShortSide; + bool request_satisfied = + std::max(tile_size_i, tile_size_j) <= TileLongSide && + std::min(tile_size_i, tile_size_j) <= TileShortSide; if (request_satisfied) { LaunchBatchNarrowMatrixTransposeKernel( @@ -662,7 +663,7 @@ struct BatchNarrowMatrixTransposeDispatcher { // determine whether it is the long side or the short side that falls short // of the request and increase that parameter accordingly. const bool long_side_request_not_satisfied = - max(tile_size_i, tile_size_j) > TileLongSide; + std::max(tile_size_i, tile_size_j) > TileLongSide; if (long_side_request_not_satisfied) { BatchNarrowMatrixTransposeDispatcher< @@ -690,8 +691,9 @@ struct BatchNarrowMatrixTransposeDispatcher< static_assert( (TileLongSide & (TileLongSide - 1)) == 0, "The length of the longer side of the tile is always a power of 2."); - bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide && - min(tile_size_i, tile_size_j) <= TileShortSide; + bool request_satisfied = + std::max(tile_size_i, tile_size_j) <= TileLongSide && + std::min(tile_size_i, tile_size_j) <= TileShortSide; if (request_satisfied) { LaunchBatchNarrowMatrixTransposeKernel( @@ -816,7 +818,7 @@ void SwapDimension1And2InTensor3WithNarrowMatrices( int tile_long_side_len = 0; int tile_short_side_len = 0; float lowest_cost = std::numeric_limits::max(); - int data_long_side = max(input_dims[1], input_dims[2]); + int data_long_side = std::max(input_dims[1], input_dims[2]); for (auto tile_size_pair : tile_spec) { int proposed_tile_long_side_len = tile_size_pair.first; @@ -861,12 +863,14 @@ void SwapDimension1And2InTensor3WithNarrowMatrices( // Truncate the shorter size requested according to the manual limit set in // tile_spec to make sure that we do not launch configurations violating // hardware limits. - requested_tile_size_i = requested_tile_size_i == tile_long_side_len - ? tile_long_side_len - : min(requested_tile_size_i, tile_short_side_len); - requested_tile_size_j = requested_tile_size_j == tile_long_side_len - ? tile_long_side_len - : min(requested_tile_size_j, tile_short_side_len); + requested_tile_size_i = + requested_tile_size_i == tile_long_side_len + ? tile_long_side_len + : std::min(requested_tile_size_i, tile_short_side_len); + requested_tile_size_j = + requested_tile_size_j == tile_long_side_len + ? tile_long_side_len + : std::min(requested_tile_size_j, tile_short_side_len); Dimension<3> input_dims_in_tiles = { input_dims[0], diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc index 8c54f22f10887da8020d8f16d21097fcb002483c..e8a58eea80e611d29886af773be5f1ee061d6f66 100644 --- a/tensorflow/core/kernels/cwise_op_maximum.cc +++ b/tensorflow/core/kernels/cwise_op_maximum.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER5(BinaryOp, CPU, "Maximum", functor::maximum, float, Eigen::half, - double, int32, int64); +REGISTER6(BinaryOp, CPU, "Maximum", functor::maximum, float, Eigen::half, + bfloat16, double, int32, int64); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Maximum", functor::maximum, float, Eigen::half, double, int64); diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 45505ef716fa801e4740424374aeb4fe8f5a29b7..c4e21257ffc4c14cac2cadd6dcae14f0900183e1 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -49,6 +49,7 @@ cc_library( srcs = ["dataset.cc"], hdrs = ["dataset.h"], deps = [ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -120,6 +121,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:batch_util", ], ) @@ -400,6 +402,19 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "tensor_queue_dataset_op", + srcs = ["tensor_queue_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:batch_util", + ], +) + tf_kernel_library( name = "tensor_slice_dataset_op", srcs = ["tensor_slice_dataset_op.cc"], @@ -538,6 +553,7 @@ tf_kernel_library( ":stats_dataset_ops", ":take_dataset_op", ":tensor_dataset_op", + ":tensor_queue_dataset_op", ":tensor_slice_dataset_op", ":unique_dataset_op", ":zip_dataset_op", diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 2d6e06398f66c0b07ae17d4fd25d7ba6b5cfef03..7fa67efb9e22e6877b97524150b9024521619dbc 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -92,7 +92,6 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { } private: - class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) @@ -145,7 +144,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { const Tensor& first_element = batch_elements[0][component_index]; TensorShape batch_component_shape({num_batch_elements}); batch_component_shape.AppendShape(first_element.shape()); - Tensor batch_component(cpu_allocator(), first_element.dtype(), + Tensor batch_component(ctx->allocator({}), first_element.dtype(), batch_component_shape); // Build the output tuple component by copying one slice // from each input element in the batch. diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 1f6d32f8df39948a4529bdf53091ff742ba88edb..f3e4f1cd3fd27c79eec4379dcd79472bde7ab5ea 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/notification.h" - namespace tensorflow { /* static */ @@ -185,8 +184,7 @@ Status CapturedFunction::MaybeInstantiate( return Status::OK(); } -Status CapturedFunction::Run(IteratorContext* ctx, - std::vector&& args, +Status CapturedFunction::Run(IteratorContext* ctx, std::vector&& args, std::vector* rets) { FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); diff --git a/tensorflow/core/kernels/data/dataset.cc b/tensorflow/core/kernels/data/dataset.cc index 2ea6875567604e4e5bf7c990ad6a42ed8c5dafaa..d18cb160189e832592b2bfdf7769396010859cc6 100644 --- a/tensorflow/core/kernels/data/dataset.cc +++ b/tensorflow/core/kernels/data/dataset.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" @@ -264,6 +265,10 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, MakeDataset(ctx, input, another_input, output); } +Allocator* IteratorContext::allocator(AllocatorAttributes attrs) { + return params_.lib->device()->GetAllocator(attrs); +} + const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] = "_DATASET_GRAPH_OUTPUT_NODE"; diff --git a/tensorflow/core/kernels/data/dataset.h b/tensorflow/core/kernels/data/dataset.h index 2ef31ddfaaa2fd1bd6a4898726d788d1ceece82e..2c6fc8d5b4f607c026e683b3086ef0cf5e9e8e76 100644 --- a/tensorflow/core/kernels/data/dataset.h +++ b/tensorflow/core/kernels/data/dataset.h @@ -15,595 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ -#include - -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/variant_encode_decode.h" -#include "tensorflow/core/framework/variant_tensor_data.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/tracing.h" - -// Polymorphic datasets should support all primitive TensorFlow -// types. Use this macro to expand `m(T)` once for each primitive type -// `T`, e.g. to build a `switch` statement. -#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) - -namespace tensorflow { - -// Interface for reading values from a key-value store. -// Used for restoring iterator state. -class IteratorStateReader { - public: - virtual Status ReadScalar(StringPiece key, int64* val) = 0; - virtual Status ReadScalar(StringPiece key, string* val) = 0; - virtual Status ReadTensor(StringPiece key, Tensor* val) = 0; - virtual bool Contains(StringPiece key) = 0; - - virtual ~IteratorStateReader() {} -}; - -// Interface for writing values to a key-value store. -// Used for saving iterator state. -class IteratorStateWriter { - public: - virtual Status WriteScalar(StringPiece key, const int64 val) = 0; - virtual Status WriteScalar(StringPiece key, const string& val) = 0; - virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; - - virtual ~IteratorStateWriter() {} -}; - -// Forward declarations to avoid introducing a dependency on headers in -// "tensorflow/core/graph/...". -class GraphDefBuilder; -class GraphDatasetBase; -class Node; - -// Wrapper around GraphDefBuilder. Used to serialize Dataset graph. -class GraphDefBuilderWrapper { - public: - explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {} - - // Adds a Const node with scalar value to the Graph. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - template - Status AddScalar(const T& val, Node** output) { - Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); - val_t.scalar()() = val; - AddTensorInternal(val_t, output); - if (*output == nullptr) { - return errors::Internal("AddScalar: Failed to build Const op."); - } - return Status::OK(); - } - - // Adds a Const node with vector value to the Graph. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? - template - Status AddVector(const std::vector& val, Node** output) { - Tensor val_t = Tensor(DataTypeToEnum::v(), - TensorShape({static_cast(val.size())})); - for (int i = 0; i < val.size(); i++) { - val_t.flat()(i) = val[i]; - } - AddTensorInternal(val_t, output); - if (*output == nullptr) { - return errors::Internal("AddVector: Failed to build Const op."); - } - return Status::OK(); - } - - // Adds a Const node with Tensor value to the Graph. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - Status AddTensor(const Tensor& val, Node** output) { - AddTensorInternal(val, output); - if (*output == nullptr) { - return errors::Internal("AddTensor: Failed to build Const op."); - } - return Status::OK(); - } - - Status AddDataset(const GraphDatasetBase* dataset, - const std::vector& inputs, Node** output) { - return AddDataset(dataset, inputs, {}, output); - } - - // Adds a node corresponding to the `DatasetType` to the Graph. - // Return value of `DatasetType::op_name()` is used as the op type for the - // node. - // Values for the output_types and output_shapes node attributes are also - // written if those attributes are defined in the OpDef. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - Status AddDataset(const GraphDatasetBase* dataset, - const std::vector& inputs, - const std::vector>& attrs, - Node** output) { - std::vector> enumerated_inputs(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - enumerated_inputs[i] = std::make_pair(i, inputs[i]); - } - return AddDataset(dataset, enumerated_inputs, {}, attrs, output); - } - - Status AddDataset( - const GraphDatasetBase* dataset, - const std::vector>& inputs, - const std::vector>>& list_inputs, - const std::vector>& attrs, - Node** output); - - // Adds a user-defined function with name `function_name` to the graph and - // recursively adds all functions it references. If a function with a matching - // name has already been added, returns with OK status. If a user-defined with - // name `function_name` is not found in the FunctionLibraryDefinition, returns - // an InvalidArgumentError. If the function with name `function_name` or any - // of its dependent functions are stateful, returns an InvalidArgument error. - Status AddFunction(OpKernelContext* ctx, const string& function_name); - - template - void BuildAttrValue(const T& value, AttrValue* attr) { - SetAttrValue(value, attr); - } - - private: - void AddTensorInternal(const Tensor& val, Node** output); - - Status EnsureFunctionIsStateless(OpKernelContext* ctx, - const string& function_name) const { - const FunctionLibraryDefinition* lib_def = - ctx->function_library()->GetFunctionLibraryDefinition(); - const FunctionDef* function_def = lib_def->Find(function_name); - if (!function_def) { - return errors::InvalidArgument("Unable to find FunctionDef for ", - function_name, " in registry."); - } - for (const NodeDef& node_def : function_def->node_def()) { - const OpDef* op_def; - TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def)); - // TODO(b/65524810): Hack to allow functions to capture Dataset op - // nodes needed for FlatMap. Currently, source datasets nodes have been - // marked stateful to avoid constant folding since we do not have a - // good way of serializing them. - if (IsOpWhitelisted(op_def)) { - continue; - } - if (op_def->is_stateful()) { - return errors::InvalidArgument( - "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ", - "in function ", function_name, " is stateful. ", - "Saving stateful functions is not supported yet."); - } - } - return Status::OK(); - } - - // Returns whether an op has been whitelisted for use inside map_fns. - // Uses a heuristic to whitelist source dataset ops which have been - // marked stateful due to b/65524810. - // Also looks up the `op_def->name` in the global - // `WhitelistedStatefulOpRegistry`. - bool IsOpWhitelisted(const OpDef* op_def) const { - return (StringPiece(op_def->name()).ends_with("Dataset") && - op_def->output_arg_size() == 1 && - op_def->output_arg(0).type() == DT_VARIANT) || - dataset::WhitelistedStatefulOpRegistry::Global()->Contains( - op_def->name()); - } - - bool HasAttr(const string& op_type_name, const string& attr_name) const; - - bool HasAttr(const OpDef* op_def, const string& attr_name) const { - for (auto attr : op_def->attr()) { - if (attr.name() == attr_name) { - return true; - } - } - return false; - } - - Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) { - if (attr_value.has_func()) { - TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); - } else if (attr_value.has_list()) { - for (const NameAttrList& name_attr_list : attr_value.list().func()) { - TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); - } - } - return Status::OK(); - } - - GraphDefBuilder* b_; -}; - -class StatsAggregator; - -// A cut-down version of OpKernelContext for running computations in -// iterators. Note that we cannot simply use OpKernelContext here -// because we might run computation in an iterator whose lifetime is -// not nested within the lifetime of a single OpKernelContext -// (e.g. asynchronous prefetching). -// -// TODO(mrry): We will probably need to support more of -// OpKernelContext here. For example, should allocation be handled by -// the IteratorContext? -// TODO(mrry): We're making some daring assumptions about the lifetime -// of the runner passed in here. A runner will be deleted when the original -// step ends, but all existing runners only close over session-lifetime (or -// longer-lived) state, so we can make a copy of the function. There's nothing -// in the definition of the API from which we took the runner to guarantee that -// what we are doing is safe. We should formalize the properties here. -class IteratorContext { - public: - struct Params { - // Interface to operating system functionality. - Env* env; - - // Function call support. - std::function)> runner = nullptr; - - // A function that returns the current `StatsAggregator` instance to be - // used when recording statistics about the iterator. - // - // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator` - // is a property of the `IteratorResource` (which this class does not know - // about), and (ii) it can change after the `IteratorContext` has been - // created. Better suggestions are welcome! - std::function()> stats_aggregator_getter = - nullptr; - - // The FunctionLibraryRuntime object to be used to make function calls. - FunctionLibraryRuntime* lib = nullptr; - std::shared_ptr function_library = nullptr; - }; - - explicit IteratorContext(Params params) : params_(std::move(params)) {} - - Env* env() const { return params_.env; } - - std::function)>* runner() { - return ¶ms_.runner; - } - - std::shared_ptr stats_aggregator() { - if (params_.stats_aggregator_getter) { - return params_.stats_aggregator_getter(); - } else { - return nullptr; - } - } - - std::shared_ptr function_library() { - return params_.function_library; - } - - FunctionLibraryRuntime* lib() { return params_.lib; } - - void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; } - - private: - Params params_; -}; - -// Represents the current position in a range of outputs, where the -// range of outputs is typically represented by an `DatasetBase`, -// defined below. -class IteratorBase { - public: - virtual ~IteratorBase() {} - - // Gets the next output from the range that this iterator is traversing. - // - // If at least one output remains in this iterator's range, that - // output will be stored in `*out_tensors` and `false` will be - // stored in `*end_of_sequence`. - // - // If no more outputs remain in this iterator's range, `true` will - // be stored in `*end_of_sequence`, and the content of - // `*out_tensors` will be undefined. - // - // This method is thread-safe. - // - // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and - // potentially remove this method. - virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) = 0; - - // Returns a vector of DataType values, representing the respective - // element types of each tuple component in the outputs of this - // iterator. - virtual const DataTypeVector& output_dtypes() const = 0; - - // Returns a vector of tensor shapes, representing the respective - // (and possibly partially defined) shapes of each tuple component - // in the outputs of this iterator. - virtual const std::vector& output_shapes() const = 0; - - // Saves the state of this iterator. - virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { - return SaveInternal(writer); - } - - // Restores the state of this iterator. - virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) { - return RestoreInternal(ctx, reader); - } - - protected: - // This is needed so that sub-classes of IteratorBase can call - // `SaveInternal` on their parent iterators, e.g., in - // `RepeatDataasetOp::Dataset`. - Status SaveParent(IteratorStateWriter* writer, - const std::unique_ptr& parent) { - return parent->SaveInternal(writer); - } - - // This is needed so that sub-classes of IteratorBase can call - // `RestoreInternal` on their parent iterators, e.g., in - // `RepeatDataasetOp::Dataset`. - Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader, - const std::unique_ptr& parent) { - return parent->RestoreInternal(ctx, reader); - } - - // Saves the state of this iterator recursively. - virtual Status SaveInternal(IteratorStateWriter* writer) { - return errors::Unimplemented("SaveInternal"); - } - - // Restores the state of this iterator recursively. - virtual Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) { - return errors::Unimplemented("RestoreInternal"); - } -}; - -// Represents a (potentially infinite) range of outputs, where each -// output is a tuple of tensors. -class DatasetBase : public core::RefCounted { - public: - // Returns a new iterator for iterating over the range of elements in - // this dataset. - // - // This method may be called multiple times on the same instance, - // and the resulting iterators will have distinct state. Each - // iterator will traverse all elements in this dataset from the - // start. - // - // Ownership of the created iterator will be transferred to the caller. - // - // The prefix identifies the sequence of iterators leading up to the newly - // created iterator. - virtual std::unique_ptr MakeIterator( - const string& prefix) const = 0; - - // Returns a vector of DataType values, representing the respective - // element types of each tuple component in the outputs of this - // dataset. - virtual const DataTypeVector& output_dtypes() const = 0; - - // Returns a vector of tensor shapes, representing the respective - // (and possibly partially defined) shapes of each tuple component - // in the outputs of this dataset. - virtual const std::vector& output_shapes() const = 0; - - // A human-readable debug string for this dataset. - virtual string DebugString() = 0; - - // Serializes the dataset and writes it to the `writer`. - virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { - return errors::Unimplemented("DatasetBase::Save"); - } - - protected: - // TODO(srbs): Ideally all graph related logic should reside in - // GraphDatasetBase. However, that would require Datasets defined in all ops - // to derive from GraphDatasetBase. Once that is done we can move - // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase. - class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { - public: - DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} - Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset, - Node** output) { - return dataset->AsGraphDefInternal(ctx, this, output); - } - }; - - virtual Status AsGraphDefInternal(OpKernelContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const { - return AsGraphDefInternal(b, node); - } - - virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, - Node** node) const { - return errors::Unimplemented("AsGraphDefInternal"); - } -}; - -// Base-class for datasets that are built by ops. -class GraphDatasetBase : public DatasetBase { - public: - GraphDatasetBase(OpKernelContext* ctx) - : op_name_(ctx->op_kernel().type_string()) {} - - const string op_name() const { return op_name_; } - - Status Save(OpKernelContext* ctx, - IteratorStateWriter* writer) const override { - string serialized_graph_def; - string output_node; - TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); - return Status::OK(); - } - - // Key for storing the Dataset graph in the serialized format. - static const char kDatasetGraphKey[]; - - // Key for storing the output node of the Dataset graph in the serialized - // format. - static const char kDatasetGraphOutputNodeKey[]; - - private: - Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, - string* output_node) const; - - const string op_name_; -}; - -// Represents an iterator that is associated with a particular parent dataset. -template -class DatasetIterator : public IteratorBase { - public: - struct Params { - // Owns one reference on the shared dataset resource. - const DatasetType* dataset; - - // Identifies the sequence of iterators leading up to this iterator. - const string prefix; - }; - - explicit DatasetIterator(const Params& params) : params_(params) { - params_.dataset->Ref(); - } - - ~DatasetIterator() override { params_.dataset->Unref(); } - - // The dataset from which this iterator was created. - const DatasetType* dataset() const { return params_.dataset; } - - // The sequence of iterators leading up to this iterator. - const string prefix() const { return params_.prefix; } - - const DataTypeVector& output_dtypes() const override { - return params_.dataset->output_dtypes(); - } - - const std::vector& output_shapes() const override { - return params_.dataset->output_shapes(); - } - - Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) final { - port::Tracing::TraceMe activity(params_.prefix); - Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); - if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) { - s = errors::Internal( - "Iterator \"", params_.prefix, - "\" returned OutOfRange without setting `*end_of_sequence`. This " - "indicates that an error may have occurred. Original message: ", - s.error_message()); - LOG(ERROR) << s; - } - return s; - } - - Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final { - TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer)); - return IteratorBase::Save(ctx, writer); - } - - protected: - // Internal implementation of GetNext that is wrapped in tracing logic. - virtual Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) = 0; - - string full_name(const string& name) const { - return strings::StrCat(prefix(), ":", name); - } - - private: - Params params_; -}; - -// Encapsulates the work required to plug a DatasetBase into the core TensorFlow -// graph execution engine. -class DatasetOpKernel : public OpKernel { - public: - DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* ctx) final; - - protected: - // Subclasses should implement this method. It will be called during Compute - // execution. - virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; - - template - Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, T* output) { - const Tensor* argument_t; - TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); - if (!TensorShapeUtils::IsScalar(argument_t->shape())) { - return errors::InvalidArgument(argument_name, " must be a scalar"); - } - *output = argument_t->scalar()(); - return Status::OK(); - } -}; - -// Encapsulates the work required to plug unary Datasets into the core -// TensorFlow graph execution engine. -class UnaryDatasetOpKernel : public DatasetOpKernel { - public: - UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} - - protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; - virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) = 0; -}; - -// Encapsulates the work required to plug binary Datasets into the core -// TensorFlow graph execution engine. -class BinaryDatasetOpKernel : public DatasetOpKernel { - public: - BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} - - protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; - virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase* another_input, - DatasetBase** output) = 0; -}; - -// Validates and extracts a `DatasetBase` object from `tensor`. -// -// `tensor` must have been written by a call to SetVariantTensorToDataset(). -// -// The retrieved pointer is a borrowed reference to the dataset, which is owned -// by the tensor. The consumer must either acquire its own reference to the -// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not -// destroyed or mutated while the retrieved pointer is in use. -Status GetDatasetFromVariantTensor(const Tensor& tensor, - DatasetBase** out_dataset); - -// Stores a `DatasetBase` object in `tensor`. -// -// The ownership of `dataset` is transferred to `tensor`. -Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); - -} // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index e7224bb547f60f943c7c91c37edfbbf561f5351a..132808a5f140a31fc3c1852cb83e5cd8579b6d95 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -155,7 +155,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { // Determine the size of the output tensors: // * dense_shape will be [`row_shape + 1`]. - Tensor dense_shape(cpu_allocator(), DT_INT64, {row_ndims + 1}); + Tensor dense_shape(ctx->allocator({}), DT_INT64, {row_ndims + 1}); auto dense_shape_vec = dense_shape.vec(); for (size_t i = 0; i < row_ndims; ++i) { if (row_shape.dim_size(i) == -1) { @@ -215,10 +215,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { // * indices will be [`total_elements`, `row_shape + 1`]. // * values will be [`total_elements`]. - Tensor indices(cpu_allocator(), DT_INT64, + Tensor indices(ctx->allocator({}), DT_INT64, {total_elements, row_ndims + 1}); Tensor values( - cpu_allocator(), + ctx->allocator({}), DatasetIterator>::dataset()->input_->output_dtypes()[0], {total_elements}); auto indices_matrix = indices.matrix(); diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index b37bd672addc5eef9eca51259dfcf86bb77782e8..8a420ac26dde851ec9ae13ecc241685a698d9fac 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/threadpool_device.h" #include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -458,7 +459,7 @@ class IteratorHandleOp : public OpKernel { { mutex_lock l(mu_); if (resource_ == nullptr) { - FunctionLibraryRuntime* lib = context->function_library(); + FunctionLibraryRuntime* lib; std::unique_ptr device_mgr(nullptr); std::unique_ptr flib_def(nullptr); std::unique_ptr pflr(nullptr); @@ -468,6 +469,9 @@ class IteratorHandleOp : public OpKernel { // is sufficient demand, but it will require a significant refactoring. if (!name_.empty()) { lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); + } else { + OP_REQUIRES_OK(context, context->function_library()->Clone( + &flib_def, &pflr, &lib)); } ResourceMgr* mgr = context->resource_manager(); @@ -516,15 +520,32 @@ class IteratorHandleOp : public OpKernel { return Status::OK(); } + template // use like this: down_cast(foo); + static inline To down_cast(From* f) { // so we only accept pointers + static_assert( + (std::is_base_of::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. + assert(f == nullptr || dynamic_cast(f) != nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + return static_cast(f); + } + FunctionLibraryRuntime* CreatePrivateFLR( OpKernelContext* ctx, std::unique_ptr* device_mgr, std::unique_ptr* flib_def, std::unique_ptr* pflr) { - Device* device = new ThreadPoolDevice( - SessionOptions(), ctx->device()->attributes().name(), Bytes(256 << 20), - DeviceLocality(), cpu_allocator()); - - device_mgr->reset(new DeviceMgr({device})); + // Wrap the existing device in order to see any captured resources + // in its resource manager. The existing device will outlive the + // IteratorResource, because we are storing the IteratorResource + // in that device's resource manager. + Device* wrapped_device = RenamedDevice::NewRenamedDevice( + ctx->device()->name(), down_cast(ctx->device()), + false /* owns_underlying */, false /* isolate_session_state */); + device_mgr->reset(new DeviceMgr({wrapped_device})); flib_def->reset(new FunctionLibraryDefinition( *ctx->function_library()->GetFunctionLibraryDefinition())); pflr->reset(new ProcessFunctionLibraryRuntime( @@ -532,7 +553,7 @@ class IteratorHandleOp : public OpKernel { {} /* TODO(mrry): OptimizerOptions? */, nullptr /* TODO(mrry): ClusterFLR */)); - return (*pflr)->GetFLR(device->name()); + return (*pflr)->GetFLR(ctx->device()->name()); } mutex mu_; diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index c529f671f2bb7fd3eb5277c23867e25ba70fd046..9ce263732f6e6c907dfdc89692455daa5dca86d1 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -183,7 +183,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { TensorShape component_shape( batch_results_[current_batch_index_].output[i].shape()); component_shape.set_dim(0, num_elements); - Tensor component(cpu_allocator(), output[i].dtype(), + Tensor component(ctx->allocator({}), output[i].dtype(), component_shape); TF_RETURN_IF_ERROR( CopyPartialBatch(&component, output[i], num_elements)); @@ -244,7 +244,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - void EnsureOutputAllocated(BatchResult* batch_result, + void EnsureOutputAllocated(IteratorContext* ctx, + BatchResult* batch_result, const std::vector& return_values) { mutex_lock l(batch_result->mu); if (batch_result->output_allocated) { @@ -254,7 +255,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { for (size_t i = 0; i < num_components; ++i) { TensorShape component_shape({dataset()->batch_size_}); component_shape.AppendShape(return_values[i].shape()); - Tensor component(cpu_allocator(), return_values[i].dtype(), + Tensor component(ctx->allocator({}), return_values[i].dtype(), component_shape); batch_result->output.emplace_back(std::move(component)); } @@ -285,10 +286,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { dataset()->captured_func_->RunAsync( ctx, std::move(input_element), &result->return_values, [this, ctx, result, batch_result, offset](Status ret_status) { - delete ctx; result->status.Update(ret_status); if (ret_status.ok()) { - EnsureOutputAllocated(batch_result, + EnsureOutputAllocated(ctx, batch_result, result->return_values); const size_t num_components = result->return_values.size(); @@ -318,6 +318,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } } } + delete ctx; // NOTE(mrry): We clear the return values here to release // any memory associated with them and to paralellize the // destruction of the tensors (which can be surprisingly diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index 346eca0bb2ab1c7a82ddba98063c0ccb71b4e58f..cfb4efda9a56fde04994201f509cf3d9fb45ea82 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/kernels/batch_util.h" #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { @@ -24,102 +25,6 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -// The following five functions are copied from padding_fifo_queue.cc. -// TODO(mrry): Reconcile these functions with the similar methods in the -// queue implementation. -Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { - DCHECK_NE(parent->dim_size(0), 0); - if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { - TensorShape chip_shape = parent->shape(); - chip_shape.RemoveDim(0); - return errors::Internal( - "HandleElementToLargerSlice Cannot copy slice: number of entries in " - "element is greater than number of elements in parent slice. ", - "Shapes are: [element]: ", element.shape().DebugString(), - ", [parent slice]: ", chip_shape.DebugString()); - } - return Status::OK(); -} - -template -Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, - int index) { - TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent)); - if (element.NumElements() == 0) { - return Status::OK(); - } - auto element_t = element.tensor(); - auto parent_t = parent->tensor(); - Eigen::DSizes slice_indices; - slice_indices[0] = index; - Eigen::DSizes slice_size; - slice_size[0] = 1; - for (size_t i = 1; i < slice_size.size(); ++i) { - slice_size[i] = element_t.dimension(i - 1); - } - parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); - return Status::OK(); -} - -template -Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, - int index) { -#define HANDLE_TYPE(T) \ - case DataTypeToEnum::value: { \ - return HandleElementToLargerSlice(element, parent, index); \ - } - - switch (element.dtype()) { - TF_CALL_DATASET_TYPES(HANDLE_TYPE); -#undef HANDLE_TYPE - default: - return errors::Unimplemented( - "HandleElementToLargerSliceWithRank Unhandled data type: ", - element.dtype()); - } -} - -Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, - int index) { - if (parent->dims() != element.dims() + 1) { - return errors::Internal( - "Mismatched ranks. Element's rank is: ", element.dims(), - " but element is meant to be a slice in output Tensor having rank: ", - parent->dims(), " (should be: ", element.dims() + 1, ")"); - } - -#define HANDLE_DIMS(NDIMS) \ - case NDIMS: { \ - TF_RETURN_IF_ERROR( \ - HandleElementToLargerSliceWithRank(element, parent, index)); \ - return Status::OK(); \ - } - - switch (element.dims()) { - HANDLE_DIMS(0); - HANDLE_DIMS(1); - HANDLE_DIMS(2); - HANDLE_DIMS(3); - HANDLE_DIMS(4); -#undef HANDLE_DIMS - default: - return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", - element.dims()); - } -} - -Status SetElementZero(Tensor* element, const Tensor& padding) { -#define HANDLE_TYPE(T) \ - if (element->dtype() == DataTypeToEnum::value) { \ - element->flat().setConstant(padding.scalar()()); \ - return Status::OK(); \ - } - TF_CALL_DATASET_TYPES(HANDLE_TYPE); -#undef HANDLE_TYPE - return errors::Unimplemented("SetElementZero Unhandled data type: ", - element->dtype()); -} - class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx) @@ -376,20 +281,27 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { // 2. Copy each batch element to the appropriate location in // the output component tensor. - Tensor batch_component(cpu_allocator(), + Tensor batch_component(ctx->allocator({}), output_dtypes()[component_index], batch_component_shape); - TF_RETURN_IF_ERROR(SetElementZero( + TF_RETURN_IF_ERROR(batch_util::SetElementZero( &batch_component, dataset()->padding_values_[component_index])); // Build the output tuple component by copying one slice // from each input element in the batch. + TensorShape component_shape({}); + for (int i = 1; i < batch_component_shape.dims(); ++i) { + component_shape.AddDim(batch_component_shape.dim_size(i)); + } for (int64 i = 0; i < num_batch_elements; ++i) { - TF_RETURN_IF_ERROR(ValidateElementToLargerSlice( - batch_elements[i][component_index], &batch_component)); - - TF_RETURN_IF_ERROR(CopyElementToLargerSlice( - batch_elements[i][component_index], &batch_component, i)); + // Take the fast path if possible. + if (batch_elements[i][component_index].shape() == component_shape) { + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( + batch_elements[i][component_index], &batch_component, i)); + } else { + TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice( + batch_elements[i][component_index], &batch_component, i)); + } } out_tensors->push_back(std::move(batch_component)); } diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index bc638864b0147f4d71b3382ea320453e972ba8d7..210b9ad1b84eeb0c106b0ee538b4957aba7ce1b2 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -99,7 +99,7 @@ class RandomDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - Tensor value_tensor(cpu_allocator(), DT_INT64, {}); + Tensor value_tensor(ctx->allocator({}), DT_INT64, {}); value_tensor.scalar()() = Random(); out_tensors->emplace_back(std::move(value_tensor)); *end_of_sequence = false; diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index d0bc61acd99afae14ddc8a3e678acb4197fcea71..b57518e678ed185a183e0413d6e90f2a9f85e9fc 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -100,7 +100,7 @@ class RangeDatasetOp : public DatasetOpKernel { *end_of_sequence = true; return Status::OK(); } - Tensor value_tensor(cpu_allocator(), DT_INT64, {}); + Tensor value_tensor(ctx->allocator({}), DT_INT64, {}); value_tensor.scalar()() = next_; out_tensors->emplace_back(std::move(value_tensor)); *end_of_sequence = false; diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index aa39fffc2e344db8143b700cbba4c29bdb134964..34d7d9f914d7a726135febabb1fbe35b0146977c 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -141,7 +141,7 @@ class TextLineDatasetOp : public DatasetOpKernel { if (s.ok()) { // Produce the line as output. - Tensor line_tensor(cpu_allocator(), DT_STRING, {}); + Tensor line_tensor(ctx->allocator({}), DT_STRING, {}); line_tensor.scalar()() = line_contents; out_tensors->emplace_back(std::move(line_tensor)); *end_of_sequence = false; @@ -384,7 +384,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR( input_buffer_->ReadNBytes(dataset()->record_bytes_, &record)); // Produce the record as output. - Tensor record_tensor(cpu_allocator(), DT_STRING, {}); + Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); record_tensor.scalar()() = record; out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; @@ -589,7 +589,7 @@ class TFRecordDatasetOp : public DatasetOpKernel { do { // We are currently processing a file, so try to read the next record. if (reader_) { - Tensor result_tensor(cpu_allocator(), DT_STRING, {}); + Tensor result_tensor(ctx->allocator({}), DT_STRING, {}); Status s = reader_->ReadRecord(&result_tensor.scalar()()); if (s.ok()) { out_tensors->emplace_back(std::move(result_tensor)); diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 13c2501bbbd43bdb6c3c521db4c3830934ee91db..d636c37afe2aa0566df7d4a38a8d393c34fd0195 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -128,8 +128,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { while (i_ < dataset()->count_) { // Fetch and throw away Tensors. std::vector dummy_out_tensors; - TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &dummy_out_tensors, - end_of_sequence)); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &dummy_out_tensors, end_of_sequence)); if (*end_of_sequence) { // We reached the end before the count was reached. input_impl_.reset(); @@ -140,8 +140,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { } // Return GetNext() on the underlying iterator. - TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, - end_of_sequence)); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); if (*end_of_sequence) { input_impl_.reset(); } @@ -184,8 +184,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { }; }; -REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), - SkipDatasetOp); +REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp); } // namespace diff --git a/tensorflow/core/kernels/data/sql/BUILD b/tensorflow/core/kernels/data/sql/BUILD index 0286825af3ef7c04fff6911ddf7daec76479a715..f4698bdaf7ae9767e068e49dad61d2a3d9f739a8 100644 --- a/tensorflow/core/kernels/data/sql/BUILD +++ b/tensorflow/core/kernels/data/sql/BUILD @@ -33,6 +33,7 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels/data:dataset", "//tensorflow/core/lib/db:sqlite", ], ) diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h index f31017bd1981c3809d9b7daaa2dc56256d19d914..e9ffca202ff32f0c0130427c2699ce0449a0903a 100644 --- a/tensorflow/core/kernels/data/sql/query_connection.h +++ b/tensorflow/core/kernels/data/sql/query_connection.h @@ -19,6 +19,8 @@ limitations under the License. namespace tensorflow { +class IteratorContext; + namespace sql { // This interface allows a user to connect to a database, execute a query, and // iterate over the result set, putting the results into an output tensor. @@ -56,7 +58,7 @@ class QueryConnection { // If there are no more rows in the result set, then instead `true` will be // stored in `*end_of_sequence`, and the content of `*out_tensors` will be // undefined. - virtual Status GetNext(std::vector* out_tensors, + virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) = 0; }; diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc index 029a0aab97290e30783e415274323a1e43f9740b..7cd07bd8eca160bfc62e15adc568742c84711779 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { @@ -48,14 +49,16 @@ Status SqliteQueryConnection::Close() { return Status::OK(); } -Status SqliteQueryConnection::GetNext(std::vector* out_tensors, +Status SqliteQueryConnection::GetNext(IteratorContext* ctx, + std::vector* out_tensors, bool* end_of_sequence) { if (!stmt_) TF_RETURN_IF_ERROR(PrepareQuery()); TF_RETURN_IF_ERROR(stmt_.Step(end_of_sequence)); if (!*end_of_sequence) { for (int i = 0; i < column_count_; i++) { DataType dt = output_types_[i]; - Tensor tensor(cpu_allocator(), dt, {}); + // TODO(mrry): Pass in the `IteratorContext::allocator()`. + Tensor tensor(ctx->allocator({}), dt, {}); FillTensorWithResultSetEntry(dt, i, &tensor); out_tensors->emplace_back(std::move(tensor)); } diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h index 787c17d6c00d99afad3d7814c3c2daaf4295b1b3..81b19530b7d5964e17bde996de9fa7766af318b7 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h @@ -32,7 +32,7 @@ class SqliteQueryConnection : public QueryConnection { Status Open(const string& data_source_name, const string& query, const DataTypeVector& output_types) override; Status Close() override; - Status GetNext(std::vector* out_tensors, + Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override; private: diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 72302190802d17f2cb1ed5471017180238aedff3..d50e9c9cf9739044379c7bbe753fc4acc2de311e 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -116,7 +116,7 @@ class SqlDatasetOp : public DatasetOpKernel { } } - Status GetNextInternal(IteratorContext* /*ctx*/, + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); @@ -132,7 +132,7 @@ class SqlDatasetOp : public DatasetOpKernel { return s; } } - return query_connection_->GetNext(out_tensors, end_of_sequence); + return query_connection_->GetNext(ctx, out_tensors, end_of_sequence); } private: diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff412a4671bd0307e4975027ebd1e098353de238 --- /dev/null +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -0,0 +1,646 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/batch_util.h" +#include "tensorflow/core/kernels/data/dataset.h" + +namespace tensorflow { + +namespace { + +bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a, + const PartialTensorShape& b) { + // Returns true if dims[a] >= dims[b], or are compatible. + if (a.unknown_rank()) return true; + if (a.dims() != b.dims()) return false; + for (int d = 0; d < a.dims(); ++d) { + if (a.dim_size(d) == -1 || b.dim_size(d) == -1) continue; + if (a.dim_size(d) < b.dim_size(d)) return false; + } + return true; +} + +DataTypeVector PrependQueueType(const DataTypeVector& dtypes) { + DataTypeVector out; + out.reserve(dtypes.size() + 1); + out.push_back(DT_VARIANT); // The queue component. + for (const DataType& d : dtypes) out.push_back(d); + return out; +} + +std::vector PrependQueueShapeWithBatch( + const std::vector& shapes) { + std::vector out; + out.reserve(shapes.size() + 1); + out.emplace_back(PartialTensorShape({-1})); // The queue component. + for (PartialTensorShape s : shapes) { + s.InsertDim(0, -1); // Unknown batch size. + out.push_back(std::move(s)); + } + return out; +} + +class EnqueueInQueueDatasetOp; + +class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { + public: + PrependFromQueueAndPaddedBatchDataset( + OpKernelContext* ctx, const int64 batch_size, const DatasetBase* input, + const DataTypeVector& dtypes, + const std::vector& shapes, + std::vector padding_values) + : GraphDatasetBase(ctx), + batch_size_(batch_size), + input_(input), + dtypes_(dtypes), + shapes_(shapes), + padding_values_(std::move(padding_values)), + dtypes_with_queue_(PrependQueueType(dtypes)), + batched_shapes_with_queue_(PrependQueueShapeWithBatch(shapes)) { + input_->Ref(); + } + + ~PrependFromQueueAndPaddedBatchDataset() override { input_->Unref(); } + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::PrependFromQueueAndPaddedBatch")})); + } + + const DataTypeVector& output_dtypes() const override { + return dtypes_with_queue_; + } + const std::vector& output_shapes() const override { + return batched_shapes_with_queue_; + } + + string DebugString() override { + return "PrependFromQueueAndPaddedBatchDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph)); + Node* batch_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); + + std::vector padded_shapes; + padded_shapes.reserve(shapes_.size()); + for (int i = 0; i < shapes_.size(); i++) { + Node* node; + Tensor t(DT_INT64, TensorShape({shapes_[i].dims()})); + for (int j = 0; j < shapes_[i].dims(); j++) { + t.vec()(j) = shapes_[i].dim_size(j); + } + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + padded_shapes.emplace_back(node); + } + + std::vector padding_values; + padding_values.reserve(padding_values_.size()); + for (const Tensor& t : padding_values_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + padding_values.emplace_back(node); + } + + AttrValue output_types; + b->BuildAttrValue(dtypes_, &output_types); + + AttrValue output_shapes; + b->BuildAttrValue(batched_shapes_with_queue_, &output_shapes); + + AttrValue N; + b->BuildAttrValue(shapes_.size(), &N); + + TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, input_graph}, {1, batch_size}}, + {{2, padded_shapes}, {3, padding_values}}, + {{"Toutput_types", output_types}, + {"output_shapes", output_shapes}, + {"N", N}}, + output)); + + return Status::OK(); + } + + private: + friend class EnqueueInQueueDatasetOp; + + class Iterator + : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + queue_(new TensorQueue(/*input_impl*/ + params.dataset->input_->MakeIterator( + params.prefix), + params.dataset->dtypes_, + params.dataset->shapes_)) {} + + ~Iterator() override { queue_->Unref(); } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + std::vector> batch; + TF_RETURN_IF_ERROR(queue_->GetNext(ctx, dataset()->batch_size_, &batch, + end_of_sequence)); + const auto& dtypes = dataset()->dtypes_; + const auto& shapes = dataset()->shapes_; + const auto& input_shapes = dataset()->input_->output_shapes(); + const auto& padding_values = dataset()->padding_values_; + const int64 batch_size = batch.size(); + out_tensors->reserve(dtypes.size()); + + std::vector max_shapes; // Of non-queue components. + for (int i = 0; i < dtypes.size(); ++i) { + const PartialTensorShape& shape = shapes[i]; + TensorShape out_shape({batch_size}); + for (int r = 0; r < shape.dims(); ++r) { + if (shape.dim_size(r) >= 0) { + // padded_shape[r] is known. + out_shape.AddDim(shape.dim_size(r)); + } else { + // padded_shape[r] is unknown, find the maximum across + // the batch. + int64 dim = 0; + for (int b = 0; b < batch.size(); ++b) { + dim = std::max(dim, batch[b][i].dim_size(r)); + } + out_shape.AddDim(dim); + } + } + max_shapes.push_back(std::move(out_shape)); + } + + Tensor queues_t(cpu_allocator(), DT_VARIANT, TensorShape({batch_size})); + if (!batch.empty()) { + auto queues = queues_t.flat(); + Variant& queue_inserter = queues(0); + queue_inserter = TensorQueueInserter(); + queue_inserter.get()->set_queue(queue_); + for (int b = 1; b < batch.size(); ++b) { + // Copy the TensorQueueInserter. Each copy increments the + // Ref on the queue_. + queues(b) = queues(0); + } + } + out_tensors->push_back(std::move(queues_t)); + + for (int i = 0; i < max_shapes.size(); ++i) { + Tensor component(cpu_allocator(), dtypes[i], max_shapes[i]); + // Try hard to take the fast path. + if (shapes[i].IsFullyDefined() && + shapes[i].IsIdenticalTo(input_shapes[i])) { + // Take the fast path if we know all the shapes statically. + for (int64 b = 0; b < batch.size(); ++b) { + TF_RETURN_IF_ERROR( + batch_util::CopyElementToSlice(batch[b][i], &component, b)); + } + } else { + TF_RETURN_IF_ERROR( + batch_util::SetElementZero(&component, padding_values[i])); + for (int64 b = 0; b < batch.size(); ++b) { + if (batch[b][i].shape() == max_shapes[i]) { + TF_RETURN_IF_ERROR( + batch_util::CopyElementToSlice(batch[b][i], &component, b)); + } else { + TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice( + batch[b][i], &component, b)); + } + } + } + out_tensors->push_back(std::move(component)); + } + + // end_of_sequence was set before we populated out_tensors, so + // it's ok to return now. + return Status::OK(); + } + + protected: + // Work around bug in MSVC that disallows access to protected + // members of Iterator from within TensorQueue. + class TensorQueue; + friend class TensorQueue; + + class TensorQueue : public core::RefCounted { + public: + TensorQueue(std::unique_ptr input_impl, + const DataTypeVector& dtypes, + const std::vector& shapes) + : dtypes_(dtypes), + shapes_(shapes), + input_impl_(std::move(input_impl)) {} + + void MaybeWaitForNotificationLocked(mutex_lock* lock) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // This essentially just releases the lock and immediately relocks. + cv_.wait_for(*lock, std::chrono::milliseconds(0)); + } + + void NotifyLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { cv_.notify_all(); } + + Status GetNext(IteratorContext* ctx, const int64 batch_size, + std::vector>* batch, + bool* end_of_sequence) { + mutex_lock lock(mu_); + + *end_of_sequence = false; + + for (int64 b = 0; b < batch_size;) { + if (!entries_.empty()) { + batch->push_back(std::move(entries_.front())); + entries_.pop_front(); + ++b; + continue; + } else { + if (input_impl_) { + // There's still input coming in. + std::vector tensors; + bool input_end; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &tensors, &input_end)); + if (!input_end) { + batch->push_back(std::move(tensors)); + ++b; + continue; + } else { + input_impl_.reset(); + } + } + if (!input_impl_) { + // There's no more input coming in. + if (RefCountIsOne()) { + // No TensorQueueInserters in the wild. + if (batch->empty()) { + *end_of_sequence = true; + } + break; + } else { + MaybeWaitForNotificationLocked(&lock); + // If there's data available, try to add entries again. + // Otherwise return a smaller batch and hope the next + // iterator request has a non-empty or unused queue_. + if (entries_.empty()) { + break; + } + } + } + } + } // for (int64 b = ... batch_size) + return Status::OK(); + } + + Status Insert(const std::vector& tensors) { + if (tensors.size() != dtypes_.size()) { + return errors::InvalidArgument( + "TensorQueue::Insert: mismatched number of tensors. Queue " + "expects ", + dtypes_.size(), " tensors but tried to insert ", tensors.size()); + } + for (int i = 0; i < tensors.size(); ++i) { + if (tensors[i].dtype() != dtypes_[i]) { + return errors::InvalidArgument( + "TensorQueue::Insert: mismatched dtypes at component ", i, + ". Attempted " + "to insert tensor of type ", + DataTypeString(tensors[i].dtype()), + " but queue expected type: ", DataTypeString(dtypes_[i])); + } + if (!shapes_[i].IsCompatibleWith(tensors[i].shape())) { + return errors::InvalidArgument( + "TensorQueue::Insert: mismatched shapes at component ", i, + ". Attempted " + "to insert tensor with shape ", + tensors[i].shape().DebugString(), + " but queue expected shape: ", shapes_[i].DebugString()); + } + } + mutex_lock lock(mu_); + entries_.push_back(tensors); + NotifyLocked(); + return Status::OK(); + } + + Status Save(Iterator* iter, IteratorStateWriter* writer) { + mutex_lock lock(mu_); + if (input_impl_) { + TF_RETURN_IF_ERROR(iter->SaveParent(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(iter->full_name("input_exhausted"), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar(iter->full_name("entries_size"), + entries_.size())); + for (int64 b = 0; b < entries_.size(); ++b) { + for (int i = 0; i < dtypes_.size(); ++i) { + TF_RETURN_IF_ERROR( + writer->WriteTensor(strings::StrCat(iter->full_name("entries"), + "[", b, "][", i, "]"), + entries_[b][i])); + } + } + return Status::OK(); + } + + Status Restore(Iterator* iter, IteratorContext* ctx, + IteratorStateReader* reader) { + mutex_lock l(mu_); + if (reader->Contains(iter->full_name("input_exhausted"))) { + input_impl_.reset(); + } else { + input_impl_ = iter->dataset_input()->MakeIterator(iter->prefix()); + TF_RETURN_IF_ERROR(iter->RestoreParent(ctx, reader, input_impl_)); + } + entries_.clear(); + int64 entries_size = -1; + TF_RETURN_IF_ERROR( + reader->ReadScalar(iter->full_name("entries_size"), &entries_size)); + if (entries_size < 0) { + return errors::DataLoss( + "Expected entries_size key '", iter->full_name("entries_size"), + "' to have nonnegative value, but saw: ", entries_size); + } + for (int64 b = 0; b < entries_size; ++b) { + std::vector entry; + for (int i = 0; i < dtypes_.size(); ++i) { + Tensor value; + TF_RETURN_IF_ERROR( + reader->ReadTensor(strings::StrCat(iter->full_name("entries"), + "[", b, "][", i, "]"), + &value)); + entry.push_back(std::move(value)); + } + entries_.push_back(std::move(entry)); + } + return Status::OK(); + } + + mutex* mu() { return &mu_; } + + private: + DataTypeVector dtypes_; + std::vector shapes_; + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::deque> entries_ GUARDED_BY(mu_); + condition_variable cv_ GUARDED_BY(mu_); + }; + + const DatasetBase* dataset_input() const { return dataset()->input_; } + + Status SaveInternal(IteratorStateWriter* writer) override { + return queue_->Save(this, writer); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return queue_->Restore(this, ctx, reader); + } + + public: + class TensorQueueInserter { + public: + TensorQueueInserter() : queue_(nullptr) {} + + void set_queue(TensorQueue* queue) { + queue_ = queue; + queue_->Ref(); + } + + TensorQueueInserter(const TensorQueueInserter& rhs) { + queue_ = rhs.queue_; + queue_->Ref(); + }; + + TensorQueueInserter(TensorQueueInserter&& rhs) { + queue_ = rhs.queue_; + rhs.queue_ = nullptr; + } + + TensorQueueInserter& operator=(const TensorQueueInserter& rhs) = delete; + + string TypeName() const { return "tensorflow::TensorQueueInserter"; } + string DebugString() const { return TypeName(); } + + void Encode(VariantTensorData*) const {} + bool Decode(const VariantTensorData&) { return false; } + + ~TensorQueueInserter() { + if (queue_) { + mutex_lock lock(*queue_->mu()); + queue_->Unref(); + queue_->NotifyLocked(); + queue_ = nullptr; + } + } + + Status Insert(const std::vector& tensors) const { + CHECK(queue_); + return queue_->Insert(tensors); + } + + private: + mutable TensorQueue* queue_; + }; + + private: + TensorQueue* const queue_; + }; + + private: + const int64 batch_size_; + const DatasetBase* input_; + const DataTypeVector dtypes_; + const std::vector shapes_; + const std::vector padding_values_; + const DataTypeVector dtypes_with_queue_; + const std::vector batched_shapes_with_queue_; +}; + +class PrependFromQueueAndPaddedBatchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit PrependFromQueueAndPaddedBatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput_types", &output_types_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 batch_size = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "batch_size", &batch_size)); + OP_REQUIRES( + ctx, batch_size > 0, + errors::InvalidArgument("Batch size must be greater than zero.")); + + OpInputList padded_shape_tensors; + OP_REQUIRES_OK(ctx, + ctx->input_list("padded_shapes", &padded_shape_tensors)); + std::vector padded_shapes; + padded_shapes.reserve(padded_shape_tensors.size()); + OP_REQUIRES(ctx, + padded_shape_tensors.size() == input->output_shapes().size(), + errors::InvalidArgument("Number of padded shapes (", + padded_shape_tensors.size(), + ") must match the number of components " + "in the input dataset's elements (", + input->output_shapes().size(), ")")); + for (const Tensor& padded_shape_t : padded_shape_tensors) { + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()), + errors::InvalidArgument("All padded shapes must be vectors")); + PartialTensorShape padded_shape; + OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape( + padded_shape_t.vec().data(), + padded_shape_t.NumElements(), &padded_shape)); + padded_shapes.push_back(std::move(padded_shape)); + } + + OP_REQUIRES( + ctx, input->output_dtypes() == output_types_, + errors::InvalidArgument("Input dataset and this dataset " + "have different output_types: ", + DataTypeVectorString(input->output_dtypes()), + " and ", DataTypeVectorString(output_types_))); + + for (int i = 0; i < input->output_shapes().size(); ++i) { + // Exclude the queue from the tensor_shapes calculation. + const PartialTensorShape& tensor_shape = padded_shapes[i]; + OP_REQUIRES( + ctx, + IsGreaterEqualToOrCompatibleWith(tensor_shape, + input->output_shapes()[i]), + errors::InvalidArgument("Incompatible input shapes at component ", i, + " between input dataset this dataset: ", + input->output_shapes()[i].DebugString(), + " vs. ", tensor_shape.DebugString())); + } + + OpInputList padding_values_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("padding_values", &padding_values_list)); + std::vector padding_values; + OP_REQUIRES(ctx, + padding_values_list.size() == input->output_shapes().size(), + errors::InvalidArgument( + "Number of padding values (", padding_values_list.size(), + ") must match the number of components in the input " + "dataset's elements (", + input->output_shapes().size(), ")")); + for (int i = 0; i < padding_values_list.size(); ++i) { + const Tensor& padding_value_t = padding_values_list[i]; + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()), + errors::InvalidArgument( + "All padding values must be scalars; but at component ", i, + " saw shape: ", padding_value_t.shape().DebugString())); + OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i], + errors::InvalidArgument( + "Mismatched type between padding value ", i, + " and input dataset's component ", i, ": ", + DataTypeString(padding_value_t.dtype()), " vs. ", + DataTypeString(input->output_dtypes()[i]))); + padding_values.push_back(padding_value_t); + } + + *output = new PrependFromQueueAndPaddedBatchDataset( + ctx, batch_size, input, output_types_, padded_shapes, + std::move(padding_values)); + } + + private: + DataTypeVector output_types_; +}; + +REGISTER_KERNEL_BUILDER( + Name("PrependFromQueueAndPaddedBatchDataset").Device(DEVICE_CPU), + PrependFromQueueAndPaddedBatchDatasetOp); + +class EnqueueInQueueDatasetOp : public OpKernel { + public: + explicit EnqueueInQueueDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + using TensorQueueInserter = + PrependFromQueueAndPaddedBatchDataset::Iterator::TensorQueueInserter; + + // TODO(ebrevdo): accept list of sequence lengths to do proper + // sub-slicing of tensors for placement into the queue? + const Tensor& tensor_queue_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_queue_t.shape()), + errors::InvalidArgument("queue must be a vector, saw shape: ", + tensor_queue_t.shape().DebugString())); + std::vector inserters; + const int64 batch_size = tensor_queue_t.NumElements(); + inserters.reserve(batch_size); + const Variant* variants = tensor_queue_t.flat().data(); + for (int i = 0; i < batch_size; ++i) { + const auto* inserter = variants[i].get(); + OP_REQUIRES(ctx, inserter != nullptr, + errors::InvalidArgument( + "Could not access TensorQueueInserter from queue[", i, + "]. Received variant: ", variants[i].DebugString())); + inserters.push_back(inserter); + } + + OpInputList components; + OP_REQUIRES_OK(ctx, ctx->input_list("components", &components)); + for (int i = 0; i < components.size(); ++i) { + OP_REQUIRES( + ctx, + components[i].dims() > 0 && components[i].dim_size(0) == batch_size, + errors::InvalidArgument( + "Expected component ", i, " to have batched shape [", batch_size, + ",...], but saw shape: ", components[i].shape().DebugString())); + } + std::vector element_shapes; + for (int i = 0; i < components.size(); ++i) { + TensorShape element_shape = components[i].shape(); + element_shape.RemoveDim(0); + element_shapes.push_back(std::move(element_shape)); + } + for (int64 b = 0; b < batch_size; ++b) { + std::vector tensors; + tensors.reserve(components.size()); + for (int i = 0; i < components.size(); ++i) { + Tensor t(components[i].dtype(), element_shapes[i]); + OP_REQUIRES_OK(ctx, + batch_util::CopySliceToElement(components[i], &t, b)); + tensors.push_back(std::move(t)); + } + // TODO(ebrevdo): Acquire the lock once for all inserters with + // the same underlying queue? Add InsertLocked? + OP_REQUIRES_OK(ctx, inserters[b]->Insert(tensors)); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU), + EnqueueInQueueDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 18adae1ea32316ffd995a95fb25198309fda3361..d5be4c778074e406122dc3a1a9c23681fca491d0 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -117,7 +117,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel { out_tensors->reserve(dataset()->tensors_.size()); for (int i = 0; i < dataset()->tensors_.size(); ++i) { const Tensor& t = dataset()->tensors_[i]; - Tensor t_slice(cpu_allocator(), t.dtype(), + Tensor t_slice(ctx->allocator({}), t.dtype(), TensorShape(dataset()->shapes_[i].dim_sizes())); TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(t, &t_slice, i_)); out_tensors->emplace_back(std::move(t_slice)); diff --git a/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc index 6d4a9dfdef4609a45d3a38e49a32492408043617..37edd1ce0f95d7f6d6a366f5b0d83bac7f6159d5 100644 --- a/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc index b084a972049cc2b1997df64a2f43a6d79b6b4e6d..f3b24b2341e590adfbeac1a18b6a65fbfd34f598 100644 --- a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc index 9dd795b94e82c48ad037df67f3218ed62feb722e..e9ffad178616a7b0872d461653cb01c40b292d88 100644 --- a/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc index 4a68a5b5803f363ab93bf280df54fa8f14206a84..020f18b1895c480748cafbfb8f7f267887db1fba 100644 --- a/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc b/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc index 2d6c82826cf9dad1ca67d6e5ee1d13a059f9c8ea..a8f07f4bad3a7e7ccff4ebefd4c56c695d0b2573 100644 --- a/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc b/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc index 81b6e491248fda37f602c0365c1e90d4b08f7c2a..f5dd47a052cd098937d66394ed04c66831ee5972 100644 --- a/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc index d91a351c5969e71385348b76376202c14e86daac..4d736a21602b34b560ea1c8d9ede4645d806ca29 100644 --- a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/identity_fuzz.cc b/tensorflow/core/kernels/fuzzing/identity_fuzz.cc index ac3a12aa399a3efe532c71c49a092b6cecd6059b..5c3fc4a2795430d1f8f269f42131e882106db7b0 100644 --- a/tensorflow/core/kernels/fuzzing/identity_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/identity_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc index 978fcd102822a6a2690478eaca473eabc6ae83ab..c90ad2cfeb7222f4c75e718fcaea6955567f3a4a 100644 --- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc index 7d1aa1fbf3a149d25e82b454543a5add522145af..738d78e99a0081a2b9f0f59c94433372acec19e2 100644 --- a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc index 94255d215e5292bf77ab1104eb1d36c0cc1d661c..e98363ffbf166782649f3fa12dc2ab70024908cf 100644 --- a/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index d6cbcf1d936b8efe861b9d2aa34b23f591324ab2..08adf4badbcd9c8baf664b13098f23dfb0584e24 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/mem.h" diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc index f0d7c670a62bf0a520cb37f01beda530d157d5c7..4040bf52bffe638d601f954f9a81d9eda78346a6 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc @@ -46,7 +46,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data, GetTopNFloatResults(data, labels, element_count); LOG(INFO) << "=== Dump ranking ==="; for (int i = 0; i < top_n; ++i) { - const std::tuple &entry = queue.top(); + const std::tuple& entry = queue.top(); LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry) << ", " << std::get<0>(entry); queue.pop(); diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index a360d188cc2246b87af348db9958152418742822..0d43d028cdbea02b820d8ac0c48378524e875e78 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -181,8 +181,8 @@ class GraphTransferer { void AppendNodeInputParams(const int id, const Node& node, const std::vector& extra_inputs); - void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, - const int id, const Node& node); + void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id, + const Node& node); static std::array BuildShapeArray( const shape_inference::ShapeHandle& shape_handle, diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc index 536d295506c9669b0434059e26094cb70a4f1e87..20b09f144bab5482f2cf1bfa86cf22f0b7ff815e 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc @@ -42,8 +42,7 @@ constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f; class GraphTransfererTest : public ::testing::Test { protected: - void SetUp() final { - } + void SetUp() final {} GraphTransferer gt_; }; @@ -61,7 +60,7 @@ class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions { } } return -1; -} + } private: const std::vector op_types_{"INPUT", "OUTPUT", "Conv2D", diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index 71bc4187b74cd6501d203aa3779c6d01e01f0d38..3f794dfb1a04cfdd6f7c114e0b2c7c0aac319a61 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -420,7 +420,7 @@ TEST(GraphTransferer, false, // is_text_proto false, // shape_inference_for_unknown_shape true // dry_run_for_unknown_shape - ); + ); ASSERT_TRUE(status.ok()) << status; prof.Stop(); prof.DumpStatistics("LoadGraphFromProtoFile"); @@ -487,7 +487,7 @@ TEST(GraphTransferer, false, // is_text_proto true, // shape_inference_for_unknown_shape false // dry_run_for_unknown_shape - ); + ); ASSERT_TRUE(status.ok()) << status; prof.Stop(); prof.DumpStatistics("LoadGraphFromProtoFile"); @@ -556,7 +556,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { false, // is_text_proto false, // shape_inference_for_unknown_shape true // dry_run_for_unknown_shape - ); + ); const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo(); ASSERT_TRUE(status.ok()); @@ -576,7 +576,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { false, // is_text_proto true, // shape_inference_for_unknown_shape false // dry_run_for_unknown_shape - ); + ); const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo(); ASSERT_TRUE(status.ok()); diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc index 1db9263e5d396b4cdb0920db18e5189149128758..a18a72c66dc659ffd372c231524dbf038df6ac22 100644 --- a/tensorflow/core/kernels/identity_op.cc +++ b/tensorflow/core/kernels/identity_op.cc @@ -128,6 +128,7 @@ REGISTER_GPU_KERNEL(Variant); REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_KERNEL(bool); +REGISTER_GPU_HOST_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc index d7fff4bb0c2b03bdfa2845f3ff89d938e07466e1..1439141f6493943c94516e6f0f9c05e8314401d5 100644 --- a/tensorflow/core/kernels/matrix_band_part_op.cc +++ b/tensorflow/core/kernels/matrix_band_part_op.cc @@ -62,7 +62,15 @@ class MatrixBandPartOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()), errors::InvalidArgument("num_lower must be scalar, got shape ", num_lower_in.shape().DebugString())); - const int64 num_lower = num_lower_in.scalar()(); + + auto as_int64_scalar = [](const Tensor& tensor) -> int64 { + if (tensor.dtype() == DT_INT32) { + return tensor.scalar()(); + } else { + return tensor.scalar()(); + } + }; + const int64 num_lower = as_int64_scalar(num_lower_in); OP_REQUIRES( context, num_lower <= input_reshaped.dimension(1), errors::InvalidArgument( @@ -73,7 +81,7 @@ class MatrixBandPartOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()), errors::InvalidArgument("num_upper must be scalar, got shape ", num_upper_in.shape().DebugString())); - const int64 num_upper = num_upper_in.scalar()(); + const int64 num_upper = as_int64_scalar(num_upper_in); OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2), errors::InvalidArgument("num_upper must be negative or less or " "equal to number of columns (", diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc index 49c34fed0224a60db90d3c21b8d6dd2c39cfae40..ef724f0a296577539fa33176e7f1a4cd55e8c663 100644 --- a/tensorflow/core/kernels/mkl_aggregate_ops.cc +++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc @@ -317,8 +317,11 @@ class MklAddNOp : public OpKernel { : src2_tensor.dims(); // if the shapes of two tensors are not same raise op error TensorShape src1_shape, src2_shape; - src1_shape = src1_tensor.shape(); - src2_shape = src2_tensor.shape(); + src1_shape = input1_in_mkl_format ? src1_mkl_shape.GetTfShape() + : src1_tensor.shape(); + src2_shape = input2_in_mkl_format ? src2_mkl_shape.GetTfShape() + : src2_tensor.shape(); + if (!src1_shape.IsSameSize(src2_shape)) { ctx->SetStatus(errors::InvalidArgument( "Inputs to operation ", this->name(), " of type ", diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index ebaa0f4e2aef072bdbc6479b72d5f8174a0bd0ea..cff1bd18a74841d91acc98e0d3cc90041a0e7142 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -468,6 +468,28 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { memory::dims output_dims_mkl_order; this->GetOutputDims(pool_params, &output_dims_mkl_order); + // If input is an empty tensor, allocate an empty output tensor and return + if (input_tensor.NumElements() == 0) { + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + TensorShape output_tf_shape; + if (pool_params.data_format == TensorFormat::FORMAT_NCHW) { + output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); + } else { + memory::dims output_dims_NHWC_order; + output_dims_NHWC_order = {pool_params.tensor_in_batch, + static_cast(pool_params.out_height), + static_cast(pool_params.out_width), + pool_params.out_depth}; + output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); + } + const int kOutputIndex = 0; + AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor, + output_tf_shape, output_mkl_shape); + CHECK_NOTNULL(output_tensor); + return; + } + // If input is in Mkl layout, then just get the memory format from it // directly, instead of using input data_format to AvgPool. if (dnn_shape_input.IsMklTensor()) { diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc index c065724e0dbbe091d253eb2315c9a5f3c041d695..58f0c30f32b0eebd7ceff856b2e3bd881b28121c 100644 --- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc +++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0(the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index 4337e4b49e420c4ea40cef87946702b724119d92..acb0db57b38c08af345dc2b22a7822c0f0f202f0 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -293,14 +293,56 @@ class MklInputConversionOp : public OpKernel { // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // If both inputs are in MKL format if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { - // If both have the same shape, pass them through if (tf_shapes_are_same) { - VLOG(1) << "MklInputConversionOp: No conversion needed, " - << "copying MKL inputs with identical shapes to output"; - - ForwardMklTensorInToOut(context, 0, 0); - ForwardMklTensorInToOut(context, 1, 1); - return; + auto input0_md = input_shape_0.GetMklLayout(); + auto input1_md = input_shape_1.GetMklLayout(); + + // If both have the same shape and same format, pass them through + if ( input0_md.data.format == input1_md.data.format) { + VLOG(1) << "MklInputConversionOp: No conversion needed, " + << "copying MKL inputs with identical shapes to output"; + + ForwardMklTensorInToOut(context, 0, 0); + ForwardMklTensorInToOut(context, 1, 1); + return; + } else { + VLOG(1) << "MklInputConversionOp: Shape is same, but format is different, " + << "need to convert to same format"; + + // Convert input0, and keep input1 unchanged + // Create MklDnnShape for output mkl tensor based on input0 + Tensor* tensor_out; + MklDnnShape mkl_output_mkl_shape; + mkl_output_mkl_shape.SetMklTensor(true); + mkl_output_mkl_shape.SetElemType(MklDnnType()); + mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(), + input_shape_0.GetSizesAsMklDnnDims(), + input_shape_0.GetTfDataFormat()); + + // Get MKL layout from input1 as destination layout + mkl_output_mkl_shape.SetMklLayout(&input1_md); + + // Create output Mkl tensor for index 0 + AllocateOutputSetMklShape(context, 0, &tensor_out, + input_tensor_0.shape(), mkl_output_mkl_shape); + + // Create MklDnnData object for input0 tesnsor + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData input(&cpu_engine); + input.SetUsrMem(input0_md, &input_tensor_0); + + // Create reorder from input0's layout to input1's layout + std::vector net; + CHECK_EQ(input.CheckReorderToOpMem(memory::primitive_desc( + input1_md, cpu_engine), + tensor_out, &net), + true); + stream(stream::kind::eager).submit(net).wait(); + + // Input1 will be passed through + ForwardMklTensorInToOut(context, 1, 1); + return; + } } // Sanity check diff --git a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc index 17f2af550f248a6924bb3d1e7546eca84d4c1e51..0e820bbb6208ae9c13ac2fb33f67590b9e66ba7e 100644 --- a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc +++ b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc @@ -71,10 +71,10 @@ class NeonDepthwiseConv2dNativeOp : public BinaryOp { filter.shape().DebugString())); const int32 in_depth = input.dim_size(3); - OP_REQUIRES( - context, in_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, in_depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, + " vs ", filter.dim_size(2))); const int32 batch = input.dim_size(0); const int32 input_rows = input.dim_size(1); const int32 input_cols = input.dim_size(2); diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index e0ae5de0f45063dd55fe567519942437e4ea889a..5645275cfa98eb820b7d1e885b18894bfab17e49 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -139,7 +139,6 @@ class PackOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_PACK); TF_CALL_QUANTIZED_TYPES(REGISTER_PACK); -TF_CALL_variant(REGISTER_PACK); #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) // Primarily used for SavedModel support on mobile. diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc index dd41744f023dd06c66e0f5a921880dbe6d5b843d..8e175fe8d4b4fa203809e5871bfd301188c985da 100644 --- a/tensorflow/core/kernels/parse_tensor_op.cc +++ b/tensorflow/core/kernels/parse_tensor_op.cc @@ -91,7 +91,6 @@ class SerializeTensorOp : public OpKernel { Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint("T"), \ SerializeTensorOp); TF_CALL_ALL_TYPES(REGISTER) -TF_CALL_variant(REGISTER) #undef REGISTER } // namespace tensorflow diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc index 8b86596721aa41c124b35b19cac7aac264b6f574..33c63e70500971cbcfb847d03239e0721d4871ff 100644 --- a/tensorflow/core/kernels/reshape_op.cc +++ b/tensorflow/core/kernels/reshape_op.cc @@ -43,7 +43,6 @@ REGISTER_KERNEL_BUILDER(Name("Reshape") .TypeConstraint("Tshape"), \ ReshapeOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); -TF_CALL_bfloat16(REGISTER_GPU_KERNEL); TF_CALL_bool(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 9cc8e03e3ac6b17f16d65f1a9ade04d8fdcba034..5b4aad3cdd83905716df0fd67cec4817e04a1ee1 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -387,7 +387,6 @@ class AssignVariableOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); -TF_CALL_variant(REGISTER_KERNELS); #undef REGISTER_KERNELS #if GOOGLE_CUDA @@ -635,6 +634,9 @@ class ResourceScatterUpdateOp : public OpKernel { TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", + scatter_op::UpdateOp::ASSIGN); + // Registers GPU kernels. #if GOOGLE_CUDA #define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 27b8081eb88a13c68d434e82c2e59d1aea068b78..6c4685a50a4139b9f33d22b409059f7c03fa2812 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -20,10 +20,10 @@ limitations under the License. #define EIGEN_USE_GPU #endif // GOOGLE_CUDA -#include "tensorflow/core/kernels/segment_reduction_ops.h" -#include #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/segment_reduction_ops.h" +#include #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -356,158 +356,180 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL); #undef REGISTER_GPU_SORTED_KERNELS_ALL #endif // GOOGLE_CUDA +// ____________________________________________________________________________ +// Unsorted segment reduction ops. + namespace functor { -// UnsortedSegmentSumFunctor implementation for CPUDevice. -// todo: Remove duplicate code in UnsortedSegmentSumFunctor and -// UnsortedSegmentMaxFunctor. -template -struct UnsortedSegmentSumFunctor - : UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const CPUDevice& d, - const Index output_rows, const TensorShape& segment_ids_shape, +// The ReductionFunctor implementation for CPU. +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const Index num_segments, + const TensorShape& segment_ids_shape, typename TTypes::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes::Tensor output) override { - output.setZero(); + typename TTypes::Tensor output) { + output.setConstant(InitialValueF()()); if (data_size == 0) { return; } const int64 N = segment_ids.dimension(0); + ReductionF reduction; auto data_flat = typename TTypes::ConstTensor(data, N, data_size / N); for (int64 i = 0; i < N; ++i) { Index j = internal::SubtleMustCopy(segment_ids(i)); if (j < 0) { continue; } - OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows), + OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments), errors::InvalidArgument( "segment_ids", SliceDebugString(segment_ids_shape, i), - " = ", j, " is out of range [0, ", output_rows, ")")); - output.template chip<0>(j) += data_flat.template chip<0>(i); + " = ", j, " is out of range [0, ", num_segments, ")")); + reduction(data_flat.template chip<0>(i), output.template chip<0>(j)); } } }; -// UnsortedSegmentMaxFunctor implementation for CPUDevice. -template -struct UnsortedSegmentMaxFunctor - : UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const CPUDevice& d, - const Index output_rows, const TensorShape& segment_ids_shape, - typename TTypes::ConstFlat segment_ids, - const Index data_size, const T* data, - typename TTypes::Tensor output) override { - output.setConstant(std::numeric_limits::lowest()); - if (data_size == 0) { - return; - } - const int64 N = segment_ids.dimension(0); - auto data_flat = typename TTypes::ConstTensor(data, N, data_size / N); - for (int64 i = 0; i < N; ++i) { - Index j = internal::SubtleMustCopy(segment_ids(i)); - OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows), - errors::InvalidArgument( - "segment_ids", SliceDebugString(segment_ids_shape, i), - " = ", j, " is out of range [0, ", output_rows, ")")); - output.template chip<0>(j) = - data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j)); - } + +template +using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes::Matrix>; + +template +using constMatrixChip = + Eigen::TensorChippingOp<0l, const typename TTypes::ConstMatrix>; + +// reduction functors +template +struct SumOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output += data; + } +}; + +template +struct MaxOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output = data.cwiseMax(output); + } +}; + +template +struct MinOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output = data.cwiseMin(output); + } +}; + +template +struct ProdOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output *= data; } }; } // namespace functor -// Base class for SegmentReductionOps that can handle unsorted segment -// definitions -// and specifying the size of the output in addition to a reduction function -template -class UnsortedSegmentBaseOp : public OpKernel { +// Static check routines not in the templated class to reduce code size +static void UnsortedSegmentReductionValidation(OpKernel* op_kernel, + OpKernelContext* context, + const Tensor& data, + const Tensor& segment_ids, + const Tensor& num_segments) { + OP_REQUIRES( + context, op_kernel->IsLegacyScalar(num_segments.shape()), + errors::InvalidArgument("num_segments should be a scalar, not shape ", + num_segments.shape().DebugString())); + OP_REQUIRES( + context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()), + errors::InvalidArgument("data.shape = ", data.shape().DebugString(), + " does not start with segment_ids.shape = ", + segment_ids.shape().DebugString())); +} + +static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel, + OpKernelContext* context, + const Tensor& data, + const Tensor& segment_ids, + const Tensor& num_segments) { + UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids, + num_segments); + return context->status().ok(); +} + +// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor +// is the device specific implementation of the reduction. These device +// specific implementations are templated themselves with the corresponding +// initial value functors and reduction functors. +template +class UnsortedSegmentReductionOp : public OpKernel { public: - explicit UnsortedSegmentBaseOp( - OpKernelConstruction* context, - functor::UnsortedSegmentBaseFunctor& functor) - : OpKernel(context), reduction_functor_(functor) {} + explicit UnsortedSegmentReductionOp(OpKernelConstruction* context) + : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {} void Compute(OpKernelContext* context) override { const Tensor& data = context->input(0); const Tensor& segment_ids = context->input(1); const Tensor& num_segments = context->input(2); - - OP_REQUIRES( - context, IsLegacyScalar(num_segments.shape()), - errors::InvalidArgument("num_segments should be a scalar, not shape ", - num_segments.shape().DebugString())); - OP_REQUIRES( - context, - TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()), - errors::InvalidArgument("data.shape = ", data.shape().DebugString(), - " does not start with segment_ids.shape = ", - segment_ids.shape().DebugString())); - + if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids, + num_segments)) { + return; + } const auto segment_flat = segment_ids.flat(); const Index output_rows = internal::SubtleMustCopy(num_segments.scalar()()); OP_REQUIRES(context, output_rows >= 0, errors::InvalidArgument("Input num_segments == ", output_rows, " must not be negative.")); - TensorShape output_shape; output_shape.AddDim(output_rows); for (int i = segment_ids.dims(); i < data.dims(); i++) { output_shape.AddDim(data.dim_size(i)); } - Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); auto output_flat = output->flat_outer_dims(); - auto data_ptr = data.template flat().data(); - reduction_functor_(context, context->template eigen_device(), - output_rows, segment_ids.shape(), segment_flat, + reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat, data.NumElements(), data_ptr, output_flat); } - private: - functor::UnsortedSegmentBaseFunctor& reduction_functor_; -}; - -template -class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp { - public: - explicit UnsortedSegmentSumOp(OpKernelConstruction* context) - : UnsortedSegmentBaseOp(context, sum_functor_) {} - - private: - functor::UnsortedSegmentSumFunctor sum_functor_; + protected: + DeviceReductionFunctor reduction_functor_; }; -template -class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp { - public: - explicit UnsortedSegmentMaxOp(OpKernelConstruction* context) - : UnsortedSegmentBaseOp(context, max_functor_) {} - - private: - functor::UnsortedSegmentMaxFunctor max_functor_; -}; - -#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentSumOp); \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentMaxOp); - -#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentSumOp); +#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT( \ + name, type, index_type, initial_value_functor, reduction_functor) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + UnsortedSegmentReductionOp< \ + type, index_type, \ + functor::UnsortedSegmentFunctor >) + +#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ + functor::Zero, \ + functor::SumOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \ + functor::Lowest, \ + functor::MaxOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \ + functor::Highest, \ + functor::MinOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ + functor::One, \ + functor::ProdOp); + +#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ + functor::Zero, \ + functor::SumOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ + functor::One, \ + functor::ProdOp) #define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \ REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \ @@ -520,31 +542,72 @@ class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp { TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL); REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64); REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128); + #undef REGISTER_REAL_CPU_UNSORTED_KERNELS +#undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL #undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL #if GOOGLE_CUDA -#define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ - .Device(DEVICE_GPU) \ - .HostMemory("num_segments") \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentSumOp); - -#define REGISTER_GPU_UNSORTED_KERNELS_ALL(type) \ - REGISTER_GPU_UNSORTED_KERNELS(type, int32); \ - REGISTER_GPU_UNSORTED_KERNELS(type, int64); +#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT( \ + name, type, index_type, initial_value_functor, reduction_kernel_functor) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_GPU) \ + .HostMemory("num_segments") \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + UnsortedSegmentReductionOp< \ + type, index_type, \ + functor::UnsortedSegmentFunctor >) + +// sum is the only op that supports all input types currently +#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \ + functor::Lowest, \ + functor::MaxOpGpu); \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \ + functor::Highest, \ + functor::MinOpGpu); \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ + functor::One, \ + functor::ProdOpGpu); + +#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ + functor::Zero, \ + functor::SumOpGpu); + +#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64); + +#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64); + + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); + +#undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT +#undef REGISTER_REAL_GPU_UNSORTED_KERNELS +#undef REGISTER_SUM_GPU_UNSORTED_KERNELS +#undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL +#undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_UNSORTED_KERNELS_ALL); -TF_CALL_complex64(REGISTER_GPU_UNSORTED_KERNELS_ALL); -TF_CALL_complex128(REGISTER_GPU_UNSORTED_KERNELS_ALL); -#undef REGISTER_GPU_UNSORTED_KERNELS -#undef REGISTER_GPU_UNSORTED_KERNELS_ALL #endif // GOOGLE_CUDA +// ____________________________________________________________________________ +// Sparse segment reduction ops. + // Same as SegmentReductionOp but takes as input a "sparse" tensor, represented // by two dense tensors, one containing the data, and the other containing // indices into the data. diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index 5c9cfe090656ff043b952192b7a4d6e8a80b692f..51814273b305bfa35bca0ddce0376658064ea56a 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ -#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" @@ -46,59 +46,81 @@ struct SegmentSumFunctor { const Index data_size, const T* data, typename TTypes::Tensor output); }; -#endif -// BaseFunctor for definition of UnsorteSegmentReductionOp -// for usage without templates. -template -struct UnsortedSegmentBaseFunctor { - virtual ~UnsortedSegmentBaseFunctor() {} - virtual void operator()(OpKernelContext* ctx, const Device& d, - const Index output_rows, - const TensorShape& segment_ids_shape, - typename TTypes::ConstFlat segment_ids, - const Index data_size, const T* data, - typename TTypes::Tensor output){}; -}; +#endif -// Functor for UnsortedSegmentSumOp. -// output_rows: the number of output segments (unique segment ids in -// 'segment_ids'). -// segment_ids_shape: shape of 'segment_ids' tensor. -// segment_ids: unsorted map from input to output segment ids at which to -// perform segment sum operation. -// data_size: size of input data tensor. -// data: input data tensor. -// output: output reshaped to {output_rows, output.size/output_rows} -template -struct UnsortedSegmentSumFunctor - : public UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const Device& d, - const Index output_rows, const TensorShape& segment_ids_shape, +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const Index num_segments, + const TensorShape& segment_ids_shape, typename TTypes::ConstFlat segment_ids, const Index data_size, const T* data, typename TTypes::Tensor output); }; -// Functor for UnsortedSegmentMaxOp. -// output_rows: the number of output segments (unique segment ids in -// 'segment_ids'). -// segment_ids_shape: shape of 'segment_ids' tensor. -// segment_ids: unsorted map from input to output segment ids at which to -// perform segment sum operation. -// data_size: size of input data tensor. -// data: input data tensor. -// output: output reshaped to {output_rows, output.size/output_rows} -template -struct UnsortedSegmentMaxFunctor - : public UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const Device& d, - const Index output_rows, const TensorShape& segment_ids_shape, - typename TTypes::ConstFlat segment_ids, - const Index data_size, const T* data, - typename TTypes::Tensor output); +#ifdef GOOGLE_CUDA +// reduction functors for the gpu +template +struct SumOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicAdd(dest, value); + } +}; + +template +struct ProdOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicMul(dest, value); + } +}; + +template +struct MaxOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicMax(dest, value); + } +}; + +template +struct MinOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicMin(dest, value); + } }; + +#endif // GOOGLE_CUDA + +// initial value functors +template +struct Zero { + EIGEN_STRONG_INLINE T operator()() const { return T(0); } +}; + +template +struct One { + EIGEN_STRONG_INLINE T operator()() const { return T(1); } +}; + +template +struct Lowest { + EIGEN_STRONG_INLINE T operator()() const { + return Eigen::NumTraits::lowest(); + } +}; + +template +struct Highest { + EIGEN_STRONG_INLINE T operator()() const { + return Eigen::NumTraits::highest(); + } +}; + } // namespace functor } // namespace tensorflow -#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc index 39d520698e1910a432de29b747a223f9e8033d24..ba979e6bb216b649ff4fc3cefa7099ac9cbc1b91 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc @@ -18,42 +18,15 @@ limitations under the License. #define EIGEN_USE_GPU #include "tensorflow/core/kernels/segment_reduction_ops.h" - #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/util/cuda_device_functions.h" #include "tensorflow/core/util/cuda_kernel_helper.h" + namespace tensorflow { using GPUDevice = Eigen::GpuDevice; -// Helper for UnusortedSegmentSumCustomKernel that adds value into dest -// atomically. -template -static __device__ __forceinline__ void AccumulateInto(T* dest, const T& value) { - CudaAtomicAdd(dest, value); -} - -// Specializations of AccumulateInto for complex types, which CudaAtomicAdd does -// not support. We treat a std::complex* as a T* (the C++ standard section -// 26.4.4 allows this explicitly) and atomic add the real and imaginary -// components individually. The operation as a whole is not atomic, but we can -// safely treat the components independently for the purpose of accumulating. -template <> -__device__ __forceinline__ void AccumulateInto( - std::complex* dest, const std::complex& value) { - auto dest_scalar = reinterpret_cast(dest); - CudaAtomicAdd(dest_scalar, value.real()); - CudaAtomicAdd(dest_scalar + 1, value.imag()); -} - -template <> -__device__ __forceinline__ void AccumulateInto( - std::complex* dest, const std::complex& value) { - auto dest_scalar = reinterpret_cast(dest); - CudaAtomicAdd(dest_scalar, value.real()); - CudaAtomicAdd(dest_scalar + 1, value.imag()); -} - // SortedSegmentSumFunctor kernel reduces input data just as // UnsortedSegmentSumCustomKernel does except that input data // is partitioned along the outer reduction dimension. This is @@ -81,7 +54,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size, const Index* segment_ids, const T* input, T* output, const Index total_stripe_count) { - CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) { + for (int stripe_index : CudaGridRangeX(total_stripe_count)) { const Index segment_offset = stripe_index % inner_dim_size; const Index input_outer_dim_index_base = stripe_index / inner_dim_size * Index(OuterDimTileSize); @@ -106,7 +79,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size, // decide whether to write result to global memory using atomic // operations if (last_output_segment_id == first_segment_id) { - AccumulateInto(output + output_index, sum); + CudaAtomicAdd(output + output_index, sum); } else { *(output + output_index) = sum; } @@ -121,31 +94,31 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size, // the following strip. const Index output_index = last_output_segment_id * inner_dim_size + segment_offset; - AccumulateInto(output + output_index, sum); + CudaAtomicAdd(output + output_index, sum); } } -// UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements. +// UnsortedSegmentSumKernel processes 'input_total_size' elements. // Each element is mapped from input to output by a combination of its // 'segment_ids' mapping and 'inner_dim_size'. -template -__global__ void UnsortedSegmentSumCustomKernel( - const Index input_outer_dim_size, const Index inner_dim_size, - const Index output_outer_dim_size, const Index* segment_ids, const T* input, - T* output) { +template +__global__ void UnsortedSegmentCustomKernel(const Index input_outer_dim_size, + const Index inner_dim_size, + const Index output_outer_dim_size, + const Index* segment_ids, + const T* input, T* output) { const Index input_total_size = input_outer_dim_size * inner_dim_size; const Index output_total_size = output_outer_dim_size * inner_dim_size; - CUDA_1D_KERNEL_LOOP(input_index, input_total_size) { + for (int input_index : CudaGridRangeX(input_total_size)) { const Index input_segment_index = input_index / inner_dim_size; const Index segment_offset = input_index % inner_dim_size; const Index output_segment_index = segment_ids[input_segment_index]; - if (output_segment_index < 0 || output_segment_index >= output_total_size) { continue; } const Index output_index = output_segment_index * inner_dim_size + segment_offset; - AccumulateInto(output + output_index, ldg(input + input_index)); + KernelReductionFunctor()(output + output_index, ldg(input + input_index)); } } @@ -190,41 +163,39 @@ void SegmentSumFunctor::operator()( <<>>( input_outer_dim_size, input_inner_dim_size, output_rows, segment_ids.data(), data, output.data(), total_stripe_count); -}; +} -// UnsortedSegmentSumFunctor implementation for GPUDevice. -template -struct UnsortedSegmentSumFunctor - : UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const GPUDevice& d, - const Index output_rows, const TensorShape& segment_ids_shape, +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const Index num_segments, + const TensorShape& segment_ids_shape, typename TTypes::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes::Tensor output) override { + typename TTypes::Tensor output) { if (output.size() == 0) { return; } - // Set 'output' to zeros. + // Set 'output' to initial value. + GPUDevice d = ctx->template eigen_device(); CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d); - SetZero<<>>( - output.size(), output.data()); + SetToValue<<>>( + output.size(), output.data(), InitialValueF()()); if (data_size == 0 || segment_ids_shape.num_elements() == 0) { return; } - - // Launch kernel to compute unsorted segment sum. + // Launch kernel to compute unsorted segment reduction. // Notes: - // *) 'input_total_size' is the total number of elements to process. + // *) 'data_size' is the total number of elements to process. // *) 'segment_ids.shape' is a prefix of data's shape. // *) 'input_outer_dim_size' is the total number of segments to process. - const Index input_total_size = data_size; const Index input_outer_dim_size = segment_ids.dimension(0); - const Index input_inner_dim_size = input_total_size / input_outer_dim_size; + const Index input_inner_dim_size = data_size / input_outer_dim_size; + config = GetCudaLaunchConfig(data_size, d); - config = GetCudaLaunchConfig(input_total_size, d); - UnsortedSegmentSumCustomKernel + UnsortedSegmentCustomKernel <<>>( - input_outer_dim_size, input_inner_dim_size, output_rows, + input_outer_dim_size, input_inner_dim_size, num_segments, segment_ids.data(), data, output.data()); } }; @@ -238,19 +209,40 @@ struct UnsortedSegmentSumFunctor TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS); -#define DEFINE_GPU_SPECS_INDEX(T, Index) \ - template struct UnsortedSegmentSumFunctor - -#define DEFINE_GPU_SPECS(T) \ - DEFINE_GPU_SPECS_INDEX(T, int32); \ - DEFINE_GPU_SPECS_INDEX(T, int64); - -TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); -TF_CALL_complex64(DEFINE_GPU_SPECS); -TF_CALL_complex128(DEFINE_GPU_SPECS); - -#undef DEFINE_GPU_SPECS -#undef DEFINE_GPU_SPECS_INDEX +#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \ + template struct UnsortedSegmentFunctor< \ + GPUDevice, T, Index, functor::Lowest, functor::MaxOpGpu>; \ + template struct UnsortedSegmentFunctor< \ + GPUDevice, T, Index, functor::Highest, functor::MinOpGpu>; \ + template struct UnsortedSegmentFunctor, \ + functor::ProdOpGpu>; + +// sum is the only op that supports all input types currently +#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \ + template struct UnsortedSegmentFunctor< \ + GPUDevice, T, Index, functor::Zero, functor::SumOpGpu>; + +#define DEFINE_REAL_GPU_SPECS(T) \ + DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \ + DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64); + +#define DEFINE_SUM_GPU_SPECS(T) \ + DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int32); \ + DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int64); + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_REAL_GPU_SPECS); +TF_CALL_int32(DEFINE_REAL_GPU_SPECS); +TF_CALL_GPU_NUMBER_TYPES(DEFINE_SUM_GPU_SPECS); +TF_CALL_int32(DEFINE_SUM_GPU_SPECS); +TF_CALL_complex64(DEFINE_SUM_GPU_SPECS); +TF_CALL_complex128(DEFINE_SUM_GPU_SPECS); + +#undef DEFINE_SORTED_GPU_SPECS_INDEX +#undef DEFINE_SORTED_GPU_SPECS +#undef DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX +#undef DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX +#undef DEFINE_REAL_GPU_SPECS +#undef DEFINE_SUM_GPU_SPECS } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 61e40caef99c019914fc331bee5d8beab0883f41..799c574d1542c345c606c276b0cc24fe61a47bba 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -426,7 +426,6 @@ class DeserializeSparseOp : public OpKernel { switch (dtype_) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_variant(HANDLE_TYPE); #undef HANDLE_TYPE default: OP_REQUIRES(context, false, diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 8f7f91c9df737053a052295cbf870a640c230d7a..7745effe2abe94ba73a2f0d761210e07c62e499c 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -294,6 +294,11 @@ class StridedSliceAssignOp : public OpKernel { OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &v)); old_lhs = *v->tensor(); + OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum::value, + errors::InvalidArgument( + "l-value dtype ", DataTypeString(old_lhs.dtype()), + " does not match r-value dtype ", + DataTypeString(DataTypeToEnum::value))); } else { context->forward_ref_input_to_ref_output(0, 0); old_lhs = context->mutable_input(0, true); diff --git a/tensorflow/core/kernels/strided_slice_op.h b/tensorflow/core/kernels/strided_slice_op.h index 0f72c4b771025458a1403ce13842787249a2718f..2b5863229860c256e1c74f1fe11bf57ed502008e 100644 --- a/tensorflow/core/kernels/strided_slice_op.h +++ b/tensorflow/core/kernels/strided_slice_op.h @@ -21,6 +21,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index ac1259a9ac4f25fac3f4a15350a50064b4f9b6a7..1c4472bb1ab4e6b9d09a1f1464577172056c6fbe 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -26,6 +26,8 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types_traits.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/ops_util.h" diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h index 58a50a70c26a63a9edd55349e2253a9ace16f1f2..49f74ff47fbc839c84465ba86e85b38cb3bd38ec 100644 --- a/tensorflow/core/lib/core/status.h +++ b/tensorflow/core/lib/core/status.h @@ -131,7 +131,7 @@ inline tensorflow::string* TfCheckOpHelper(::tensorflow::Status v, while (auto _result = ::tensorflow::TfCheckOpHelper(val, #val)) \ LOG(level) << *(_result) -#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) +#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) #define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL) // DEBUG only version of TF_CHECK_OK. Compiler still parses 'val' even in opt diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index 2b10ebeaf7cbed4a8466a69898d6d4d6660ed5cb..e55ed79d36cd2db7a6f6b19f3579f47e73b4b2d9 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -66,7 +66,9 @@ struct EigenEnvironment { } return Task{ std::unique_ptr(new TaskImpl{ - std::move(f), Context(ContextKind::kThread), id, + std::move(f), + Context(ContextKind::kThread), + id, }), }; } diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc index 49ddb16645c32a82d90eafa5f550b8887ac84b79..627ef5a892a35ec43d0c31220dcf046b4b8eda55 100644 --- a/tensorflow/core/lib/core/threadpool_test.cc +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -97,8 +97,8 @@ TEST(ThreadPool, ParallelForWithWorkerId) { } pool.ParallelForWithWorkerId( kWorkItems, kHugeCost, - [&threads_running, &work, num_threads]( - int64 begin, int64 end, int64 id) { + [&threads_running, &work, num_threads](int64 begin, int64 end, + int64 id) { // Store true for the current thread, and assert that another thread // is not running with the same id. ASSERT_LE(0, id); diff --git a/tensorflow/core/lib/db/sqlite.h b/tensorflow/core/lib/db/sqlite.h index 0faa458f1d692a103099d5b05d0400944ffdaad7..efe97f78d259199a74bf5e830f70de657d1cd679 100644 --- a/tensorflow/core/lib/db/sqlite.h +++ b/tensorflow/core/lib/db/sqlite.h @@ -18,12 +18,12 @@ limitations under the License. #include #include "sqlite3.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/lib/core/refcount.h" /// TensorFlow SQLite Veneer /// @@ -121,10 +121,7 @@ class LOCKABLE Sqlite : public core::RefCounted { Sqlite(sqlite3* db, sqlite3_stmt* begin, sqlite3_stmt* commit, sqlite3_stmt* rollback) noexcept - : db_(db), - begin_(begin), - commit_(commit), - rollback_(rollback) {} + : db_(db), begin_(begin), commit_(commit), rollback_(rollback) {} sqlite3* const db_; sqlite3_stmt* const begin_; @@ -233,7 +230,8 @@ class SqliteStatement { /// freed until this statement is Reset() or finalized. void BindText(int parameter, const StringPiece& text) { Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), - SQLITE_TRANSIENT, SQLITE_UTF8), parameter); + SQLITE_TRANSIENT, SQLITE_UTF8), + parameter); size_ += text.size(); } void BindText(const char* parameter, const StringPiece& text) { @@ -241,7 +239,8 @@ class SqliteStatement { } void BindTextUnsafe(int parameter, const StringPiece& text) { Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), - SQLITE_STATIC, SQLITE_UTF8), parameter); + SQLITE_STATIC, SQLITE_UTF8), + parameter); size_ += text.size(); } void BindTextUnsafe(const char* parameter, const StringPiece& text) { @@ -254,7 +253,8 @@ class SqliteStatement { /// freed until this statement is Reset() or finalized. void BindBlob(int parameter, const StringPiece& blob) { Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), - SQLITE_TRANSIENT), parameter); + SQLITE_TRANSIENT), + parameter); size_ += blob.size(); } void BindBlob(const char* parameter, const StringPiece& blob) { @@ -262,7 +262,8 @@ class SqliteStatement { } void BindBlobUnsafe(int parameter, const StringPiece& blob) { Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), - SQLITE_STATIC), parameter); + SQLITE_STATIC), + parameter); size_ += blob.size(); } void BindBlobUnsafe(const char* parameter, const StringPiece& text) { @@ -320,9 +321,7 @@ class SqliteStatement { /// \brief Move constructor, after which is reset to empty. SqliteStatement(SqliteStatement&& other) noexcept - : db_(other.db_), - stmt_(other.stmt_), - bind_error_(other.bind_error_) { + : db_(other.db_), stmt_(other.stmt_), bind_error_(other.bind_error_) { other.db_ = nullptr; other.stmt_ = nullptr; other.bind_error_ = SQLITE_OK; diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc index c9c76ea5f2cd30b8abe7e3c9766ce4946ca25200..1e88323d017bec4b2705c6dbb19005efb8adbaa9 100644 --- a/tensorflow/core/lib/db/sqlite_test.cc +++ b/tensorflow/core/lib/db/sqlite_test.cc @@ -33,9 +33,7 @@ class SqliteTest : public ::testing::Test { db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie(); } - void TearDown() override { - db_->Unref(); - } + void TearDown() override { db_->Unref(); } Sqlite* db_; bool is_done_; @@ -213,7 +211,7 @@ TEST_F(SqliteTest, BindFailed) { Status s = stmt.StepOnce(); EXPECT_NE(string::npos, s.error_message().find("INSERT INTO T (a) VALUES (123)")) - << s.error_message(); + << s.error_message(); } TEST_F(SqliteTest, SnappyExtension) { @@ -226,7 +224,7 @@ TEST_F(SqliteTest, SnappyBinaryCompatibility) { EXPECT_EQ( "today is the end of the republic", db_->PrepareOrDie("SELECT UNSNAP(X'03207C746F6461792069732074686520656E64" - "206F66207468652072657075626C6963')") + "206F66207468652072657075626C6963')") .StepOnceOrDie() .ColumnString(0)); } diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index 0f6999c88fca3fd7ab91d2f3e28348e22d106f45..e5deb2b873e22249cc52323b1b29518e4255d48a 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -44,6 +44,14 @@ int input_callback(GifFileType* gif_file, GifByteType* buf, int size) { return 0; } +static const char* GifErrorStringNonNull(int error_code) { + const char* error_string = GifErrorString(error_code); + if (error_string == nullptr) { + return "Unknown error"; + } + return error_string; +} + uint8* Decode(const void* srcdata, int datasize, const std::function& allocate_output, string* error_string) { @@ -55,17 +63,17 @@ uint8* Decode(const void* srcdata, int datasize, int error_code = D_GIF_SUCCEEDED; if (gif_file && DGifCloseFile(gif_file, &error_code) != GIF_OK) { LOG(WARNING) << "Fail to close gif file, reason: " - << GifErrorString(error_code); + << GifErrorStringNonNull(error_code); } }); if (error_code != D_GIF_SUCCEEDED) { *error_string = strings::StrCat("failed to open gif file: ", - GifErrorString(error_code)); + GifErrorStringNonNull(error_code)); return nullptr; } if (DGifSlurp(gif_file) != GIF_OK) { *error_string = strings::StrCat("failed to slurp gif file: ", - GifErrorString(gif_file->Error)); + GifErrorStringNonNull(gif_file->Error)); return nullptr; } if (gif_file->ImageCount <= 0) { diff --git a/tensorflow/core/lib/gtl/cleanup.h b/tensorflow/core/lib/gtl/cleanup.h index 6053e986402598568299d1756d23068693c193c8..6bd60ca482430cf13f4f076badf460cf2e1d593b 100644 --- a/tensorflow/core/lib/gtl/cleanup.h +++ b/tensorflow/core/lib/gtl/cleanup.h @@ -55,22 +55,21 @@ namespace gtl { template class Cleanup { public: - Cleanup() - : released_(true), f_() {} + Cleanup() : released_(true), f_() {} template - explicit Cleanup(G&& f) // NOLINT + explicit Cleanup(G&& f) // NOLINT : f_(std::forward(f)) {} // NOLINT(build/c++11) Cleanup(Cleanup&& src) // NOLINT - : released_(src.is_released()), f_(src.release()) { } + : released_(src.is_released()), f_(src.release()) {} // Implicitly move-constructible from any compatible Cleanup. // The source will be released as if src.release() were called. // A moved-from Cleanup can be safely destroyed or reassigned. template Cleanup(Cleanup&& src) // NOLINT - : released_(src.is_released()), f_(src.release()) { } + : released_(src.is_released()), f_(src.release()) {} // Assignment to a Cleanup object behaves like destroying it // and making a new one in its place, analogous to unique_ptr @@ -102,8 +101,8 @@ class Cleanup { F f_; }; -template ::type> +template ::type> TF_MUST_USE_RESULT Cleanup MakeCleanup(F&& f) { return Cleanup(std::forward(f)); } diff --git a/tensorflow/core/lib/gtl/cleanup_test.cc b/tensorflow/core/lib/gtl/cleanup_test.cc index bd151cb2ab1c8a830eb1bd9546ab452d05c6c20c..a86ffd5fe284485f15fa824026e8d79f5191a384 100644 --- a/tensorflow/core/lib/gtl/cleanup_test.cc +++ b/tensorflow/core/lib/gtl/cleanup_test.cc @@ -65,15 +65,14 @@ TEST(CleanupTest, Release) { TEST(FinallyTest, TypeErasedWithoutFactory) { string s = "active"; { - AnyCleanup s_cleaner([&s]{ s.append(" clean"); }); + AnyCleanup s_cleaner([&s] { s.append(" clean"); }); EXPECT_EQ("active", s); } EXPECT_EQ("active clean", s); } struct Appender { - Appender(string* s, const string& msg) - : s_(s), msg_(msg) {} + Appender(string* s, const string& msg) : s_(s), msg_(msg) {} void operator()() const { s_->append(msg_); } string* s_; string msg_; @@ -163,7 +162,12 @@ class CleanupReferenceTest : public ::testing::Test { int* i; F(int* cp, int* i) : cp(cp), i(i) {} F(const F& o) : cp(o.cp), i(o.i) { ++*cp; } - F& operator=(const F& o) { cp = o.cp; i = o.i; ++*cp; return *this; } + F& operator=(const F& o) { + cp = o.cp; + i = o.i; + ++*cp; + return *this; + } F(F&&) = default; F& operator=(F&&) = default; void operator()() const { ++*i; } @@ -279,7 +283,7 @@ BENCHMARK(BM_AnyCleanup); void BM_AnyCleanupNoFactory(int iters) { while (iters--) { - AnyCleanup fin([]{Incr();}); + AnyCleanup fin([] { Incr(); }); } } BENCHMARK(BM_AnyCleanupNoFactory); diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index d6e5d9effa794c46b7aa98691bb993dbd7e764c8..6e3cb2206d9658a3b0bc24b506049f503ae304ed 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -31,12 +31,12 @@ limitations under the License. #ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ -#include #include #include #include #include #include +#include #include #include #include @@ -407,7 +407,7 @@ class InlinedVector { }; // 2) Construct a T with args at not-yet-initialized memory pointed by dst. struct Construct { - template + template void operator()(T* dst, Args&&... args) const { new (dst) T(std::forward(args)...); } diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h index 647fc81aa7e4925d1d2b74b82146b18b0c17a4a9..af3e50ad78ff9d07bc0e8e79a5ff7cb3d1aacbfe 100644 --- a/tensorflow/core/lib/gtl/int_type.h +++ b/tensorflow/core/lib/gtl/int_type.h @@ -255,13 +255,13 @@ class IntType { value_ op arg_value; \ return *this; \ } - INT_TYPE_ASSIGNMENT_OP(+= ); - INT_TYPE_ASSIGNMENT_OP(-= ); - INT_TYPE_ASSIGNMENT_OP(*= ); - INT_TYPE_ASSIGNMENT_OP(/= ); - INT_TYPE_ASSIGNMENT_OP(<<= ); // NOLINT - INT_TYPE_ASSIGNMENT_OP(>>= ); // NOLINT - INT_TYPE_ASSIGNMENT_OP(%= ); + INT_TYPE_ASSIGNMENT_OP(+=); + INT_TYPE_ASSIGNMENT_OP(-=); + INT_TYPE_ASSIGNMENT_OP(*=); + INT_TYPE_ASSIGNMENT_OP(/=); + INT_TYPE_ASSIGNMENT_OP(<<=); // NOLINT + INT_TYPE_ASSIGNMENT_OP(>>=); // NOLINT + INT_TYPE_ASSIGNMENT_OP(%=); #undef INT_TYPE_ASSIGNMENT_OP ThisType& operator=(ValueType arg_value) { @@ -314,10 +314,10 @@ std::ostream& operator<<(std::ostream& os, // NOLINT INT_TYPE_ARITHMETIC_OP(+); INT_TYPE_ARITHMETIC_OP(-); INT_TYPE_ARITHMETIC_OP(*); -INT_TYPE_ARITHMETIC_OP(/ ); -INT_TYPE_ARITHMETIC_OP(<< ); // NOLINT -INT_TYPE_ARITHMETIC_OP(>> ); // NOLINT -INT_TYPE_ARITHMETIC_OP(% ); +INT_TYPE_ARITHMETIC_OP(/); +INT_TYPE_ARITHMETIC_OP(<<); // NOLINT +INT_TYPE_ARITHMETIC_OP(>>); // NOLINT +INT_TYPE_ARITHMETIC_OP(%); #undef INT_TYPE_ARITHMETIC_OP // -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------ @@ -345,12 +345,12 @@ INT_TYPE_ARITHMETIC_OP(% ); IntType id) { \ return val op id.value(); \ } -INT_TYPE_COMPARISON_OP(== ); // NOLINT -INT_TYPE_COMPARISON_OP(!= ); // NOLINT -INT_TYPE_COMPARISON_OP(< ); // NOLINT -INT_TYPE_COMPARISON_OP(<= ); // NOLINT -INT_TYPE_COMPARISON_OP(> ); // NOLINT -INT_TYPE_COMPARISON_OP(>= ); // NOLINT +INT_TYPE_COMPARISON_OP(==); // NOLINT +INT_TYPE_COMPARISON_OP(!=); // NOLINT +INT_TYPE_COMPARISON_OP(<); // NOLINT +INT_TYPE_COMPARISON_OP(<=); // NOLINT +INT_TYPE_COMPARISON_OP(>); // NOLINT +INT_TYPE_COMPARISON_OP(>=); // NOLINT #undef INT_TYPE_COMPARISON_OP } // namespace gtl diff --git a/tensorflow/core/lib/gtl/int_type_test.cc b/tensorflow/core/lib/gtl/int_type_test.cc index d3c405d9acdb221f465e98d957ba55ba6bc63f57..61d364017cb90933e8e9af7e800db4a6988d8442 100644 --- a/tensorflow/core/lib/gtl/int_type_test.cc +++ b/tensorflow/core/lib/gtl/int_type_test.cc @@ -42,7 +42,8 @@ class IntTypeTest : public ::testing::Test { // All tests below will be executed on all supported IntTypes. typedef ::testing::Types SupportedIntTypes; + Int64_IT, UInt64_IT, Long_IT> + SupportedIntTypes; TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes); @@ -232,7 +233,8 @@ TYPED_TEST(IntTypeTest, TestOperators) { TYPED_TEST(IntTypeTest, TestHashFunctor) { std::unordered_map map; + typename TestFixture::T::Hasher> + map; typename TestFixture::T a(0); map[a] = 'c'; EXPECT_EQ('c', map[a]); diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h index 2ff8b9c7d1adbbc206e0429142389e9730efa33c..fa33c24c0c006aa5d3fed5102980da865e12696a 100644 --- a/tensorflow/core/lib/gtl/optional.h +++ b/tensorflow/core/lib/gtl/optional.h @@ -593,12 +593,12 @@ class optional : private internal_optional::optional_data, assert(this->engaged_); return this->pointer(); } - constexpr const T& operator*() const & { return reference(); } + constexpr const T& operator*() const& { return reference(); } T& operator*() & { assert(this->engaged_); return reference(); } - constexpr const T&& operator*() const && { return std::move(reference()); } + constexpr const T&& operator*() const&& { return std::move(reference()); } T&& operator*() && { assert(this->engaged_); return std::move(reference()); @@ -621,7 +621,7 @@ class optional : private internal_optional::optional_data, // Use `opt.value()` to get a reference to underlying value. The constness // and lvalue/rvalue-ness of `opt` is preserved to the view of the T // subobject. - const T& value() const & { + const T& value() const& { CHECK(*this) << "Bad optional access"; return reference(); } @@ -633,7 +633,7 @@ class optional : private internal_optional::optional_data, CHECK(*this) << "Bad optional access"; return std::move(reference()); } - const T&& value() const && { // NOLINT(build/c++11) + const T&& value() const&& { // NOLINT(build/c++11) CHECK(*this) << "Bad optional access"; return std::move(reference()); } @@ -641,7 +641,7 @@ class optional : private internal_optional::optional_data, // Use `opt.value_or(val)` to get either the value of T or the given default // `val` in the empty case. template - constexpr T value_or(U&& v) const & { + constexpr T value_or(U&& v) const& { return static_cast(*this) ? **this : static_cast(std::forward(v)); } @@ -656,8 +656,8 @@ class optional : private internal_optional::optional_data, constexpr const T& reference() const { return *this->pointer(); } T& reference() { return *(this->pointer()); } - // T constraint checks. You can't have an optional of nullopt_t, in_place_t or - // a reference. + // T constraint checks. You can't have an optional of nullopt_t, in_place_t + // or a reference. static_assert( !std::is_same::type>::value, "optional is not allowed."); diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc index 547bee7b75f3d05e290ec7d53d889ff7e82794a9..12b5bbc60be9961a5f852210c42479b2cd48ea92 100644 --- a/tensorflow/core/lib/gtl/optional_test.cc +++ b/tensorflow/core/lib/gtl/optional_test.cc @@ -24,17 +24,29 @@ limitations under the License. namespace tensorflow { namespace { -using tensorflow::gtl::optional; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::nullopt_t; using tensorflow::gtl::in_place; using tensorflow::gtl::in_place_t; using tensorflow::gtl::make_optional; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::nullopt_t; +using tensorflow::gtl::optional; -template string TypeQuals(T&) { return "&"; } -template string TypeQuals(T&&) { return "&&"; } -template string TypeQuals(const T&) { return "c&"; } -template string TypeQuals(const T&&) { return "c&&"; } +template +string TypeQuals(T&) { + return "&"; +} +template +string TypeQuals(T&&) { + return "&&"; +} +template +string TypeQuals(const T&) { + return "c&"; +} +template +string TypeQuals(const T&&) { + return "c&&"; +} struct StructorListener { int construct0 = 0; diff --git a/tensorflow/core/lib/gtl/top_n_test.cc b/tensorflow/core/lib/gtl/top_n_test.cc index fae85570dc071568a53abcb72fea6ffc22a465ea..ba30c072a9033073a7439f60dbfa3402dbfc5923 100644 --- a/tensorflow/core/lib/gtl/top_n_test.cc +++ b/tensorflow/core/lib/gtl/top_n_test.cc @@ -28,10 +28,10 @@ limitations under the License. namespace { +using tensorflow::string; using tensorflow::gtl::TopN; using tensorflow::random::PhiloxRandom; using tensorflow::random::SimplePhilox; -using tensorflow::string; // Move the contents from an owned raw pointer, returning by value. // Objects are easier to manage by value. diff --git a/tensorflow/core/lib/io/compression.cc b/tensorflow/core/lib/io/compression.cc index c12de98e40105907460f74f967e20aa41bdb0ceb..0d25bca9eccf2b28800a288858ffbc0caeb2dbd3 100644 --- a/tensorflow/core/lib/io/compression.cc +++ b/tensorflow/core/lib/io/compression.cc @@ -22,6 +22,6 @@ namespace compression { const char kNone[] = ""; const char kGzip[] = "GZIP"; -} -} -} +} // namespace compression +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h index ef90c60a3a411cdc94a9f92522116db340e04f1b..4d8e7788cad823e0e79a4e9567c6f17a3d9259cf 100644 --- a/tensorflow/core/lib/io/compression.h +++ b/tensorflow/core/lib/io/compression.h @@ -23,8 +23,8 @@ namespace compression { extern const char kNone[]; extern const char kGzip[]; -} -} -} +} // namespace compression +} // namespace io +} // namespace tensorflow #endif // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 403c82818ef3293a1dc027d362eb766906d0e94a..9cc6c4034f485c497747d102d7d731e5cd68a4d0 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -207,7 +207,7 @@ Status RecordReader::SkipNBytes(uint64 offset) { } } return Status::OK(); -} +} // namespace io SequentialRecordReader::SequentialRecordReader( RandomAccessFile* file, const RecordReaderOptions& options) diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc index 507c26a63ff587809e80739f8d015d1adcc3b21d..b7e51256a22b0d84e734e2a036a184b3adc3e547 100644 --- a/tensorflow/core/lib/io/recordio_test.cc +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -218,8 +218,8 @@ TEST_F(RecordioTest, RandomRead) { // Tests of all the error paths in log_reader.cc follow: static void AssertHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " - << expected; + EXPECT_TRUE(StringPiece(s).contains(expected)) + << s << " does not contain " << expected; } TEST_F(RecordioTest, ReadError) { diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc index 354c819b090ce5e04047f13d2ff19441a499d770..77a3414442caa523ab7a92e3e63babf581030287 100644 --- a/tensorflow/core/lib/png/png_io.cc +++ b/tensorflow/core/lib/png/png_io.cc @@ -197,8 +197,8 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels, int desired_channel_bits, DecodeContext* context) { CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) << "desired_channel_bits = " << desired_channel_bits; - CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = " - << desired_channels; + CHECK(0 <= desired_channels && desired_channels <= 4) + << "desired_channels = " << desired_channels; context->error_condition = false; context->channels = desired_channels; context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h index f4bb087e107e10f90196a807c03ed2407d9d1ad6..6c29ae6b6a224d9c0369172bbf21af465ad53a19 100644 --- a/tensorflow/core/lib/random/philox_random_test_utils.h +++ b/tensorflow/core/lib/random/philox_random_test_utils.h @@ -35,8 +35,8 @@ void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p, int64 size) { const int granularity = Distribution::kResultElementCount; - CHECK(size % granularity == 0) << " size: " << size - << " granularity: " << granularity; + CHECK(size % granularity == 0) + << " size: " << size << " granularity: " << granularity; Distribution dist; for (int i = 0; i < size; i += granularity) { diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h index 0e281403f8748ffbb7dbfac888cd2303c0a7253f..3fe1f9bc6cf06158df4811eaa177988b60890006 100644 --- a/tensorflow/core/lib/random/random_distributions.h +++ b/tensorflow/core/lib/random/random_distributions.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #define _USE_MATH_DEFINES -#include #include +#include #undef _USE_MATH_DEFINES #include @@ -27,7 +27,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/lib/random/philox_random.h" - namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc index 90d0dba4a7793f51472b2e5434489448eb40a498..85d68f456e1e27b7a62315f2b0a962843da87d52 100644 --- a/tensorflow/core/lib/random/random_distributions_test.cc +++ b/tensorflow/core/lib/random/random_distributions_test.cc @@ -45,8 +45,8 @@ void FillRandomsWithSingles(PhiloxRandom gen, int64 size) { int granularity = Distribution::kResultElementCount; - CHECK(size % granularity == 0) << " size: " << size - << " granularity: " << granularity; + CHECK(size % granularity == 0) + << " size: " << size << " granularity: " << granularity; SingleSampleAdapter single_samples(&gen); diff --git a/tensorflow/core/lib/strings/ordered_code.cc b/tensorflow/core/lib/strings/ordered_code.cc index af9a15125948d8ed390e5873f3677527ebddea8e..ef90050b4f628ab65c1dd939ba358fec714c95b5 100644 --- a/tensorflow/core/lib/strings/ordered_code.cc +++ b/tensorflow/core/lib/strings/ordered_code.cc @@ -472,7 +472,8 @@ void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) { // buf = val in network byte order, sign extended to 10 bytes const char sign_byte = val < 0 ? '\xff' : '\0'; char buf[10] = { - sign_byte, sign_byte, + sign_byte, + sign_byte, }; StoreBigEndian64(buf + 2, val); static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch"); diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h index 5835b0101d9ede219a71acf554c5928e4b624ce7..2bc14945cd0413751003c03c7f5255c300790321 100644 --- a/tensorflow/core/lib/strings/strcat.h +++ b/tensorflow/core/lib/strings/strcat.h @@ -126,7 +126,7 @@ class AlphaNum { : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {} AlphaNum(const Eigen::half &f); // NOLINT(runtime/explicit) - AlphaNum(Hex hex); // NOLINT(runtime/explicit) + AlphaNum(Hex hex); // NOLINT(runtime/explicit) AlphaNum(const char *c_str) : piece_(c_str) {} // NOLINT(runtime/explicit) AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit) diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 5ec2a4e9b4c43451637610797dae0bd8d189f8d1..267ce88440080399aae783903503f0bbd025d8b4 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -708,10 +708,11 @@ REGISTER_OP("MatrixDiagPart") // -------------------------------------------------------------------------- REGISTER_OP("MatrixBandPart") .Input("input: T") - .Input("num_lower: int64") - .Input("num_upper: int64") + .Input("num_lower: Tindex") + .Input("num_upper: Tindex") .Output("band: T") .Attr("T: type") + .Attr("Tindex: {int32, int64} = DT_INT64") .SetShapeFn(shape_inference::UnchangedShape); // -------------------------------------------------------------------------- diff --git a/tensorflow/core/ops/compat/backwards_compatibility_test.cc b/tensorflow/core/ops/compat/backwards_compatibility_test.cc index add05d6610ae62158b653d27699f61bc511ee3b6..6e05ae4be4fb967ac8dcc5a03fa548c7cb6c0f9b 100644 --- a/tensorflow/core/ops/compat/backwards_compatibility_test.cc +++ b/tensorflow/core/ops/compat/backwards_compatibility_test.cc @@ -25,8 +25,9 @@ namespace tensorflow { namespace { TEST(BackwardsCompatibilityTest, IsCompatible) { - OpCompatibilityLib compatibility( - "tensorflow/core/ops", strings::StrCat("v", TF_MAJOR_VERSION), nullptr); + OpCompatibilityLib compatibility("tensorflow/core/ops", + strings::StrCat("v", TF_MAJOR_VERSION), + nullptr); Env* env = Env::Default(); int changed_ops = 0; diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 65ab81931ad4261f432034f73269d1e8c8005384..2580eaf987bc7dc6413296f0537a2e71f38921ee 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -17136,6 +17136,24 @@ op { type: DT_STRING } } +op { + name: "EnqueueInQueueDataset" + input_arg { + name: "queue" + type: DT_VARIANT + } + input_arg { + name: "components" + type_list_attr: "Tcomponents" + } + attr { + name: "Tcomponents" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Enter" input_arg { @@ -24840,6 +24858,42 @@ op { type: "type" } } +op { + name: "MatrixBandPart" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "num_lower" + type_attr: "Tindex" + } + input_arg { + name: "num_upper" + type_attr: "Tindex" + } + output_arg { + name: "band" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tindex" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "MatrixDeterminant" input_arg { @@ -32096,6 +32150,48 @@ op { minimum: 1 } } +op { + name: "PrependFromQueueAndPaddedBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "padded_shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "padding_values" + type_list_attr: "Toutput_types" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "Toutput_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } +} op { name: "PreventGradient" input_arg { @@ -42820,6 +42916,36 @@ op { } is_stateful: true } +op { + name: "ResourceScatterUpdate" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} op { name: "ResourceSparseApplyAdadelta" input_arg { diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 2cae814eab1602e72ffcfd100f9813f8f41c6ac9..3c8e9a8a5f2e1e0d1b26da7580fe2e5e0d1771dd 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -491,4 +491,29 @@ REGISTER_OP("StatsAggregatorSummary") .Output("summary: string") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("PrependFromQueueAndPaddedBatchDataset") + .Input("input_dataset: variant") + .Input("batch_size: int64") + .Input("padded_shapes: N * int64") + .Input("padding_values: Toutput_types") + .Output("handle: variant") + .Attr("Toutput_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("N: int >= 1") + // TODO(ebrevdo): Validate that `padded_shapes` are all vectors, the lengths + // of `Toutput_types` and `output_shapes` are `N`, that the + // length of `output_types` is `N`, the `output_shapes` are + // (as far as possible to tell statically) compatible with `padded_shapes`, + // and that `padding_values` are all scalars. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("EnqueueInQueueDataset") + .Input("queue: variant") + .Input("components: Tcomponents") + .Attr("Tcomponents: list(type) >= 1") + .SetIsStateful() // To avoid CSE on multiple calls to Enqueue. + // TODO(ebrevdo): SetShapeFn to test input dtypes and shapes by + // reading from queue handle (is that even possible?). + .SetShapeFn(shape_inference::NoOutputs); + } // namespace tensorflow diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index ef2ac267cc96419d5e78eae8e296aa150a4927ca..a62e2d782b8d542b98494cc42ccf6f86d295efd0 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -586,6 +586,17 @@ REGISTER_OP("NonMaxSuppression") .Output("selected_indices: int32") .Attr("iou_threshold: float = 0.5") .SetShapeFn([](InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); + ShapeHandle scores; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); + ShapeHandle max_output_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); + // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); + c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); }); @@ -597,6 +608,19 @@ REGISTER_OP("NonMaxSuppressionV2") .Input("iou_threshold: float") .Output("selected_indices: int32") .SetShapeFn([](InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); + ShapeHandle scores; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); + ShapeHandle max_output_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); + ShapeHandle iou_threshold; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); + // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); + c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); }); diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index a67267418d608e7c824030225f906b010794a160..444aa8b9544c62d81f288f21e4eaaac23d8691cb 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/shape_inference.h" @@ -102,6 +103,8 @@ REGISTER_OP("LookupTableFindV2") c->set_output(0, c->UnknownShape()); return Status::OK(); }); +WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2"); +// TODO(b/72710477): Update this. REGISTER_OP("LookupTableInsert") .Input("table_handle: Ref(string)") diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 872ebe98c1f331ca882480282d3f8eecf4ce5f2d..8f33d51d5a20fc207102e4bf79e7605d9817eb9f 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1065,6 +1065,26 @@ REGISTER_OP("UnsortedSegmentMax") .Attr("Tnumsegments: {int32,int64} = DT_INT32") .SetShapeFn(UnsortedSegmentReductionShapeFn); +REGISTER_OP("UnsortedSegmentMin") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(UnsortedSegmentReductionShapeFn); + +REGISTER_OP("UnsortedSegmentProd") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(UnsortedSegmentReductionShapeFn); + REGISTER_OP("SparseSegmentSum") .Input("data: T") .Input("indices: Tidx") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index b57206c9c4f53fbf73537f466206f5c1b0caefcb..8df126735b54bee779e5a7c6dbc09b074b6c8a90 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -7644,6 +7644,24 @@ op { type: DT_STRING } } +op { + name: "EnqueueInQueueDataset" + input_arg { + name: "queue" + type: DT_VARIANT + } + input_arg { + name: "components" + type_list_attr: "Tcomponents" + } + attr { + name: "Tcomponents" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Enter" input_arg { @@ -12330,11 +12348,11 @@ op { } input_arg { name: "num_lower" - type: DT_INT64 + type_attr: "Tindex" } input_arg { name: "num_upper" - type: DT_INT64 + type_attr: "Tindex" } output_arg { name: "band" @@ -12344,6 +12362,19 @@ op { name: "T" type: "type" } + attr { + name: "Tindex" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } } op { name: "MatrixDeterminant" @@ -15926,6 +15957,48 @@ op { minimum: 1 } } +op { + name: "PrependFromQueueAndPaddedBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "padded_shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "padding_values" + type_list_attr: "Toutput_types" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "Toutput_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } +} op { name: "PreventGradient" input_arg { @@ -20925,27 +20998,6 @@ op { attr { name: "dtype" type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_INT64 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_BFLOAT16 - type: DT_UINT16 - type: DT_COMPLEX128 - type: DT_HALF - type: DT_UINT32 - type: DT_UINT64 - } - } } attr { name: "Tindices" diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index f6cfbf873a024e3a035842468fc5ccca2d341ce7..8dae7e1ff5f872c33dd56509c0349180cec78593 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -193,7 +193,7 @@ REGISTER_OP("ResourceScatterUpdate") .Input("resource: resource") .Input("indices: Tindices") .Input("updates: dtype") - .Attr("dtype: numbertype") + .Attr("dtype: type") .Attr("Tindices: {int32, int64}") .SetShapeFn([](InferenceContext* c) { ShapeAndType handle_shape_and_type; diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 07aecf848326b23b18b58ae60e896150ab7b4ef9..9ba25dea4fb278cbfaf4080e21beef8a3e9de769 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -57,6 +57,17 @@ cc_library( ], ) +cc_library( + name = "gcs_throttle", + srcs = ["gcs_throttle.cc"], + hdrs = ["gcs_throttle.h"], + copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/core:lib", + ], +) + cc_library( name = "gcs_file_system", srcs = ["gcs_file_system.cc"], @@ -69,6 +80,7 @@ cc_library( ":expiring_lru_cache", ":file_block_cache", ":gcs_dns_cache", + ":gcs_throttle", ":google_auth_provider", ":http_request", ":retrying_file_system", @@ -271,6 +283,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gcs_throttle_test", + size = "small", + srcs = ["gcs_throttle_test.cc"], + linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]), + deps = [ + ":gcs_throttle", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "curl_http_request_test", size = "small", diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.cc b/tensorflow/core/platform/cloud/gcs_dns_cache.cc index 2b0e55bf371da9660f1422cef97e3ec1a25a9b61..4d9aff4d24f06c7bd1269ad590c9687092a5b132 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache.cc +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #else +#include #include #include -#include #endif #include diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 91d381bd6f72ebcb40d85b5f15d5bc568a4ff03f..01ca0d76bab2720513775ef33ff8670bd148c241 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -116,6 +116,15 @@ constexpr char kWriteRequestTimeout[] = "GCS_WRITE_REQUEST_TIMEOUT_SECS"; // The environment variable to configure an additional header to send with // all requests to GCS (format HEADERNAME:HEADERCONTENT) constexpr char kAdditionalRequestHeader[] = "GCS_ADDITIONAL_REQUEST_HEADER"; +// The environment variable to configure the throttle (format: ) +constexpr char kThrottleRate[] = "GCS_THROTTLE_TOKEN_RATE"; +// The environment variable to configure the token bucket size (format: ) +constexpr char kThrottleBucket[] = "GCS_THROTTLE_BUCKET_SIZE"; +// The environment variable that controls the number of tokens per request. +// (format: ) +constexpr char kTokensPerRequest[] = "GCS_TOKENS_PER_REQUEST"; +// The environment variable to configure the initial tokens (format: ) +constexpr char kInitialTokens[] = "GCS_INITIAL_TOKENS"; // TODO: DO NOT use a hardcoded path Status GetTmpFilename(string* filename) { @@ -721,6 +730,26 @@ GcsFileSystem::GcsFileSystem() if (GetEnvVar(kWriteRequestTimeout, strings::safe_strtou32, &timeout_value)) { timeouts_.write = timeout_value; } + + int64 token_value; + if (GetEnvVar(kThrottleRate, strings::safe_strto64, &token_value)) { + GcsThrottleConfig config; + config.enabled = true; + config.token_rate = token_value; + + if (GetEnvVar(kThrottleBucket, strings::safe_strto64, &token_value)) { + config.bucket_size = token_value; + } + + if (GetEnvVar(kTokensPerRequest, strings::safe_strto64, &token_value)) { + config.tokens_per_request = token_value; + } + + if (GetEnvVar(kInitialTokens, strings::safe_strto64, &token_value)) { + config.initial_tokens = token_value; + } + throttle_.SetConfig(config); + } } GcsFileSystem::GcsFileSystem( @@ -774,7 +803,9 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object)); std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request), + "when reading gs://", bucket, "/", object); + request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket, "/", request->EscapeString(object))); request->SetRange(offset, offset + n - 1); @@ -789,6 +820,8 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ " << offset << " of size: " << bytes_read; + throttle_.RecordResponse(bytes_read); + if (bytes_read < block_size()) { // Check stat cache to see if we encountered an interrupted read. FileStatistics stat; @@ -926,41 +959,43 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, "'object' must be a non-empty string. (File: %s)", fname.c_str())); } - StatCache::ComputeFunc compute_func = - [this, &bucket, &object](const string& fname, FileStatistics* stat) { - std::vector output_buffer; - std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); - request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/", - request->EscapeString(object), - "?fields=size%2Cupdated")); - request->SetResultBuffer(&output_buffer); - request->SetTimeouts(timeouts_.connect, timeouts_.idle, - timeouts_.metadata); + StatCache::ComputeFunc compute_func = [this, &bucket, &object]( + const string& fname, + FileStatistics* stat) { + std::vector output_buffer; + std::unique_ptr request; + TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request), + " when reading metadata of gs://", bucket, + "/", object); + + request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/", + request->EscapeString(object), + "?fields=size%2Cupdated")); + request->SetResultBuffer(&output_buffer); + request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata); - TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), - " when reading metadata of gs://", - bucket, "/", object); + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), + " when reading metadata of gs://", bucket, + "/", object); - Json::Value root; - TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); + Json::Value root; + TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); - // Parse file size. - TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->length)); + // Parse file size. + TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->length)); - // Parse file modification time. - string updated; - TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated)); - TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->mtime_nsec))); + // Parse file modification time. + string updated; + TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated)); + TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->mtime_nsec))); - VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- " - << " length: " << stat->length - << "; mtime_nsec: " << stat->mtime_nsec - << "; updated: " << updated; + VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- " + << " length: " << stat->length + << "; mtime_nsec: " << stat->mtime_nsec << "; updated: " << updated; - stat->is_directory = false; - return Status::OK(); - }; + stat->is_directory = false; + return Status::OK(); + }; TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(fname, stat, compute_func)); if (stat->is_directory) { @@ -1438,6 +1473,10 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr* request) { additional_header_->second); } + if (!throttle_.AdmitRequest()) { + return errors::Unavailable("Request throttled"); + } + *request = std::move(new_request); return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 2eae39608e38184450290e86bc12d81494bb8302..e8edde8a445aad4c0310394d89480dc6ae445dfa 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/expiring_lru_cache.h" #include "tensorflow/core/platform/cloud/file_block_cache.h" #include "tensorflow/core/platform/cloud/gcs_dns_cache.h" +#include "tensorflow/core/platform/cloud/gcs_throttle.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_file_system.h" #include "tensorflow/core/platform/file_system.h" @@ -194,6 +195,7 @@ class GcsFileSystem : public FileSystem { std::unique_ptr http_request_factory_; std::unique_ptr file_block_cache_; std::unique_ptr dns_cache_; + GcsThrottle throttle_; using StatCache = ExpiringLRUCache; std::unique_ptr stat_cache_; diff --git a/tensorflow/core/platform/cloud/gcs_throttle.cc b/tensorflow/core/platform/cloud/gcs_throttle.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb5f8958a37f45aeac1a836ca037f91931bb34a6 --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_throttle.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_throttle.h" + +#include + +namespace tensorflow { + +GcsThrottle::GcsThrottle(EnvTime* env_time) + : last_updated_secs_(env_time->NowSeconds()), + available_tokens_(0), + env_time_(env_time) {} + +bool GcsThrottle::AdmitRequest() { + mutex_lock l(mu_); + if (!config_.enabled) return true; + UpdateState(); + if (available_tokens_ < config_.tokens_per_request) { + return false; + } + available_tokens_ -= config_.tokens_per_request; + return true; +} + +void GcsThrottle::RecordResponse(size_t num_bytes) { + mutex_lock l(mu_); + if (!config_.enabled) return; + UpdateState(); + available_tokens_ -= request_bytes_to_tokens(num_bytes); +} + +void GcsThrottle::SetConfig(GcsThrottleConfig config) { + mutex_lock l(mu_); + config_ = config; + available_tokens_ = config.initial_tokens; + last_updated_secs_ = env_time_->NowSeconds(); +} + +void GcsThrottle::UpdateState() { + // TODO(b/72643279): Switch to a monotonic clock. + int64 now = env_time_->NowSeconds(); + uint64 delta_secs = + std::max(0LL, now - static_cast(last_updated_secs_)); + available_tokens_ += delta_secs * config_.token_rate; + available_tokens_ = std::min(available_tokens_, config_.bucket_size); + last_updated_secs_ = now; +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_throttle.h b/tensorflow/core/platform/cloud/gcs_throttle.h new file mode 100644 index 0000000000000000000000000000000000000000..1a89daef084e921f1ad8bd856cefcc62d0d7aa1c --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_throttle.h @@ -0,0 +1,156 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ + +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +/** + * GcsThrottleConfig is used to configure the GcsThrottle. + */ +struct GcsThrottleConfig { + /** + * enabled is true if GcsThrottle should throttle requests, false otherwise. + */ + bool enabled = false; + + /** + * token_rate is the number of tokens accrued every second that can be used + * for making requests to the GCS service. + */ + int64 token_rate = 100000; // Approximately 800 MBits/second bandwidth-only. + + /** + * bucket_size is the maximum number of available tokens the GcsThrottle can + * accrue. + */ + int64 bucket_size = 10000000; // 10 million tokens total + + /** + * tokens_per_request determines the number of tokens consumed for every + * request. + * + * Note: tokens are also consumed in proportion to the response size. + */ + int64 tokens_per_request = 100; + + /** + * initial_tokens determines how many tokens should be available immediately + * after the GcsThrottle is constructed. + */ + int64 initial_tokens = 0; +}; + +/** + * GcsThrottle is used to ensure fair use of the available GCS capacity. + * + * GcsThrottle operates around a concept of tokens. Tokens are consumed when + * making requests to the GCS service. Tokens are consumed both based on the + * number of requests made, as well as the bandwidth consumed (response sizes). + * + * GcsThrottle is thread safe and can be used from multiple threads. + */ +class GcsThrottle { + public: + /** + * Constructs a GcsThrottle. + */ + explicit GcsThrottle(EnvTime* env_time = EnvTime::Default()); + + /** + * AdmitRequest updates the GcsThrottle to record a request will be made. + * + * AdmitRequest should be called before any request is made. AdmitRequest + * returns false if the request should be denied. If AdmitRequest + * returns false, no tokens are consumed. If true is returned, the configured + * number of tokens are consumed. + */ + bool AdmitRequest(); + + /** + * RecordResponse updates the GcsThrottle to record a request has been made. + * + * RecordResponse should be called after the response has been received. + * RecordResponse will update the internal state based on the number of bytes + * in the response. + * + * Note: we split up the request and the response in this fashion in order to + * avoid penalizing consumers who are using large readahead buffers at higher + * layers of the I/O stack. + */ + void RecordResponse(size_t num_bytes); + + /** + * SetConfig sets the configuration for GcsThrottle and re-initializes state. + * + * After calling this, the token pool will be config.initial_tokens. + */ + void SetConfig(GcsThrottleConfig config); + + /** + * available_tokens gives a snapshot of how many tokens are available. + * + * The returned value should not be used to make admission decisions. The + * purpose of this function is to make available to monitoring or other + * instrumentation the number of available tokens in the pool. + */ + inline int64 available_tokens() { + mutex_lock l(mu_); + if (!config_.enabled) return 0; + UpdateState(); + return available_tokens_; + } + + private: + /** + * UpdateState updates the available_tokens_ and last_updated_secs_ variables. + * + * UpdateState should be called in order to mark the passage of time, and + * therefore add tokens to the availble_tokens_ pool. + */ + void UpdateState() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + inline uint64 request_bytes_to_tokens(size_t num_bytes) { + return num_bytes >> 10; + } + + mutex mu_; + + /** + * last_updated_secs_ records the number of seconds since the Unix epoch that + * the internal state of the GcsThrottle was updated. This is important when + * determining the number of tokens to add to the available_tokens_ pool. + */ + uint64 last_updated_secs_ GUARDED_BY(mu_) = 0; + + /** + * available_tokens_ records how many tokens are available to be consumed. + * + * Note: it is possible for available_tokens_ to become negative. If a + * response comes back that consumes more than the available tokens, the count + * will go negative, and block future requests until we have available tokens. + */ + int64 available_tokens_ GUARDED_BY(mu_) = 0; + + EnvTime* const env_time_; + GcsThrottleConfig config_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ diff --git a/tensorflow/core/platform/cloud/gcs_throttle_test.cc b/tensorflow/core/platform/cloud/gcs_throttle_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..694756022e37263a07f8215bf7496c9ca130fd58 --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_throttle_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_throttle.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +class TestTime : public EnvTime { + public: + uint64 NowMicros() override { return now_; } + + void SetTime(uint64 now_micros) { now_ = now_micros; } + + void AdvanceSeconds(int64 secs) { now_ += secs * 1000000L; } + + private: + uint64 now_ = 1234567890000000ULL; +}; + +class GcsThrottleTest : public ::testing::Test { + protected: + GcsThrottleTest() : throttle_(&time_) { + config_.enabled = true; + throttle_.SetConfig(config_); + } + + GcsThrottleConfig config_; + TestTime time_; + GcsThrottle throttle_; +}; + +TEST_F(GcsThrottleTest, ReplenishTokens) { + EXPECT_EQ(0, throttle_.available_tokens()); + time_.AdvanceSeconds(1); + EXPECT_EQ(100000, throttle_.available_tokens()); + time_.AdvanceSeconds(2); + EXPECT_EQ(300000, throttle_.available_tokens()); +} + +TEST_F(GcsThrottleTest, RejectRequest) { + EXPECT_EQ(0, throttle_.available_tokens()); + time_.AdvanceSeconds(1); + EXPECT_TRUE(throttle_.AdmitRequest()); + EXPECT_EQ(99900, throttle_.available_tokens()); + for (int i = 1; i < 1000; i++) { + EXPECT_TRUE(throttle_.AdmitRequest()); + } + EXPECT_FALSE(throttle_.AdmitRequest()); +} + +TEST_F(GcsThrottleTest, MarkResponses) { + time_.AdvanceSeconds(1); + EXPECT_TRUE(throttle_.AdmitRequest()); + throttle_.RecordResponse(128000000); // 128 MB response + EXPECT_EQ(-25100, throttle_.available_tokens()); + EXPECT_FALSE(throttle_.AdmitRequest()); + time_.AdvanceSeconds(1); + EXPECT_TRUE(throttle_.AdmitRequest()) + << "Available tokens: " << throttle_.available_tokens(); +} + +TEST_F(GcsThrottleTest, Skippingtime_) { + EXPECT_EQ(0, throttle_.available_tokens()); + time_.AdvanceSeconds(90); + EXPECT_EQ(9000000, throttle_.available_tokens()); +} + +TEST_F(GcsThrottleTest, BucketLimit) { + time_.AdvanceSeconds(120); + EXPECT_EQ(10000000, throttle_.available_tokens()); +} + +TEST_F(GcsThrottleTest, ReverseTime) { + time_.AdvanceSeconds(1); + EXPECT_EQ(100000, throttle_.available_tokens()); + time_.AdvanceSeconds(-3600); + EXPECT_EQ(100000, throttle_.available_tokens()); + time_.AdvanceSeconds(1); + EXPECT_EQ(200000, throttle_.available_tokens()); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h index 682b97f6ec6d697bef2ef6301a39be35c95c5861..7711eaceb290fb21c54c9656c473d912ebbd84cf 100644 --- a/tensorflow/core/platform/cloud/http_request_fake.h +++ b/tensorflow/core/platform/cloud/http_request_fake.h @@ -38,8 +38,7 @@ class FakeHttpRequest : public CurlHttpRequest { public: /// Return the response for the given request. FakeHttpRequest(const string& request, const string& response) - : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) { - } + : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {} /// Return the response with headers for the given request. FakeHttpRequest(const string& request, const string& response, diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 236259dbc16ffc806779bd100e1ec6ace2b7bb39..ad569758cc6ec11555a81a3bc7fbefbc580d6529 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -160,12 +160,12 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key)); ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(), header_dot_claim.size())); - ASSERT_EQ( - 1, - EVP_DigestVerifyFinal( - md_ctx, const_cast( - reinterpret_cast(signature.data())), - signature.size())); + ASSERT_EQ(1, + EVP_DigestVerifyFinal( + md_ctx, + const_cast( + reinterpret_cast(signature.data())), + signature.size())); EVP_MD_CTX_cleanup(md_ctx); // Free all the crypto-related resources. diff --git a/tensorflow/core/platform/cloud/retrying_file_system.cc b/tensorflow/core/platform/cloud/retrying_file_system.cc index c3b6831361305f69e8a9882dbff90ce139ca13c0..be9ebe67b18e7be76e95149258cb1fcce6047d85 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system.cc @@ -25,7 +25,6 @@ namespace tensorflow { namespace { - class RetryingRandomAccessFile : public RandomAccessFile { public: RetryingRandomAccessFile(std::unique_ptr base_file, @@ -203,4 +202,6 @@ Status RetryingFileSystem::DeleteRecursively(const string& dirname, initial_delay_microseconds_); } +void RetryingFileSystem::FlushCaches() { base_file_system_->FlushCaches(); } + } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index d9d8ea6b004c3cf1d0d77ff65fa415e746310afd..a262a5fd940f9b269721790c80caaef38d79d690 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -69,6 +69,8 @@ class RetryingFileSystem : public FileSystem { Status DeleteRecursively(const string& dirname, int64* undeleted_files, int64* undeleted_dirs) override; + void FlushCaches() override; + private: std::unique_ptr base_file_system_; const int64 initial_delay_microseconds_; diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index 232dcb3e71aa7c5b05b45e37332fe58970fc3fe8..d3f763bb3c845436e8458135a0a754d8cb002957 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -84,7 +84,8 @@ class MockWritableFile : public WritableFile { class MockFileSystem : public FileSystem { public: - explicit MockFileSystem(const ExpectedCalls& calls) : calls_(calls) {} + explicit MockFileSystem(const ExpectedCalls& calls, bool* flushed = nullptr) + : calls_(calls), flushed_(flushed) {} Status NewRandomAccessFile( const string& fname, std::unique_ptr* result) override { @@ -156,11 +157,18 @@ class MockFileSystem : public FileSystem { return calls_.ConsumeNextCall("DeleteRecursively"); } + void FlushCaches() override { + if (flushed_) { + *flushed_ = true; + } + } + std::unique_ptr writable_file_to_return; std::unique_ptr random_access_file_to_return; private: MockCallSequence calls_; + bool* flushed_ = nullptr; }; TEST(RetryingFileSystemTest, NewRandomAccessFile_ImmediateSuccess) { @@ -702,5 +710,14 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) { << status; } +TEST(RetryingFileSystemTest, FlushCaches) { + ExpectedCalls none; + bool flushed = false; + std::unique_ptr base_fs(new MockFileSystem(none, &flushed)); + RetryingFileSystem fs(std::move(base_fs), 0); + fs.FlushCaches(); + EXPECT_TRUE(flushed); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/platform/cuda_libdevice_path_test.cc b/tensorflow/core/platform/cuda_libdevice_path_test.cc index 639f6804ea236b86f458263091f371c1374e50ae..2d34239a9958d722a1cb84213657ca8229ebaf2c 100644 --- a/tensorflow/core/platform/cuda_libdevice_path_test.cc +++ b/tensorflow/core/platform/cuda_libdevice_path_test.cc @@ -27,8 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) { VLOG(2) << "Libdevice root = " << LibdeviceRoot(); std::vector libdevice_files; TF_EXPECT_OK(Env::Default()->GetMatchingPaths( - io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), - &libdevice_files)); + io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), &libdevice_files)); EXPECT_LT(0, libdevice_files.size()); } #endif diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 119ffa3d9e887676e0d970358e589cb1403cbd28..2102c5cca383b553c56fb3704596e3d1335c55c2 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -489,12 +489,6 @@ def tf_additional_core_deps(): "//tensorflow/core/platform/s3:s3_file_system", ], "//conditions:default": [], - }) + select({ - "//tensorflow:with_kafka_support": [ - "//tensorflow/contrib/kafka:kafka_kernels", - "//tensorflow/contrib/kafka:kafka_ops_op_lib", - ], - "//conditions:default": [], }) # TODO(jart, jhseu): Delete when GCP is default on. diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc index f4b0f16393d70521386ad49fbf010591e5afb08c..8e60a7f0910ff9cf77a33f9d72d680ec42847777 100644 --- a/tensorflow/core/platform/default/device_tracer.cc +++ b/tensorflow/core/platform/default/device_tracer.cc @@ -579,8 +579,10 @@ Status DeviceTracerImpl::Collect(StepStatsCollector *collector) { // TODO(pbar) Handle device IDs and prefix properly. const string prefix = ""; const int id = 0; - const string stream_device = strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); - const string memcpy_device = strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); + const string stream_device = + strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); + const string memcpy_device = + strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); mutex_lock l2(trace_mu_); for (const auto &rec : kernel_records_) { diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc index 82bd69f9ca46eb1b8dd586d18ed852a2e8c5084e..2b874da1981bed396330ca3c526d82779046bdf2 100644 --- a/tensorflow/core/platform/default/logging.cc +++ b/tensorflow/core/platform/default/logging.cc @@ -83,15 +83,14 @@ void LogMessage::GenerateLogMessage() { const size_t time_buffer_size = 30; char time_buffer[time_buffer_size]; strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", - localtime(&now_seconds)); + localtime(&now_seconds)); // TODO(jeff,sanjay): Replace this with something that logs through the env. fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder, - "IWEF"[severity_], fname_, line_, str().c_str()); + "IWEF"[severity_], fname_, line_, str().c_str()); } #endif - namespace { // Parse log level (int64) from environment variable (char*) diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h index 40c260f236613e533e30dc006e77b02f393bdd48..f0efa31d5576393e9d9bba6e39a454b2a33cddc3 100644 --- a/tensorflow/core/platform/default/logging.h +++ b/tensorflow/core/platform/default/logging.h @@ -19,8 +19,8 @@ limitations under the License. // IWYU pragma: private, include "third_party/tensorflow/core/platform/logging.h" // IWYU pragma: friend third_party/tensorflow/core/platform/logging.h -#include #include +#include #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -205,16 +205,18 @@ string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) { inline string* name##Impl(int v1, int v2, const char* exprtext) { \ return name##Impl(v1, v2, exprtext); \ } \ - inline string* name##Impl(const size_t v1, const int v2, const char* exprtext) { \ + inline string* name##Impl(const size_t v1, const int v2, \ + const char* exprtext) { \ if (TF_PREDICT_FALSE(v2 < 0)) { \ - return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext);\ + return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \ } \ const size_t uval = (size_t)((unsigned)v1); \ return name##Impl(uval, v2, exprtext); \ } \ - inline string* name##Impl(const int v1, const size_t v2, const char* exprtext) { \ - if (TF_PREDICT_FALSE(v2 >= std::numeric_limits::max())) { \ - return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext);\ + inline string* name##Impl(const int v1, const size_t v2, \ + const char* exprtext) { \ + if (TF_PREDICT_FALSE(v2 >= std::numeric_limits::max())) { \ + return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \ } \ const size_t uval = (size_t)((unsigned)v2); \ return name##Impl(v1, uval, exprtext); \ @@ -225,12 +227,12 @@ string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) { // This happens if, for example, those are used as token names in a // yacc grammar. TF_DEFINE_CHECK_OP_IMPL(Check_EQ, - == ) // Compilation error with CHECK_EQ(NULL, x)? -TF_DEFINE_CHECK_OP_IMPL(Check_NE, != ) // Use CHECK(x == NULL) instead. -TF_DEFINE_CHECK_OP_IMPL(Check_LE, <= ) -TF_DEFINE_CHECK_OP_IMPL(Check_LT, < ) -TF_DEFINE_CHECK_OP_IMPL(Check_GE, >= ) -TF_DEFINE_CHECK_OP_IMPL(Check_GT, > ) + ==) // Compilation error with CHECK_EQ(NULL, x)? +TF_DEFINE_CHECK_OP_IMPL(Check_NE, !=) // Use CHECK(x == NULL) instead. +TF_DEFINE_CHECK_OP_IMPL(Check_LE, <=) +TF_DEFINE_CHECK_OP_IMPL(Check_LT, <) +TF_DEFINE_CHECK_OP_IMPL(Check_GE, >=) +TF_DEFINE_CHECK_OP_IMPL(Check_GT, >) #undef TF_DEFINE_CHECK_OP_IMPL // In optimized mode, use CheckOpString to hint to compiler that diff --git a/tensorflow/core/platform/denormal.cc b/tensorflow/core/platform/denormal.cc index f13b0af2a79bec4538c64cbc475681f6eb0ce127..e00dbdb4ae5ef682369b345353e236a6084460ef 100644 --- a/tensorflow/core/platform/denormal.cc +++ b/tensorflow/core/platform/denormal.cc @@ -41,8 +41,8 @@ namespace tensorflow { namespace port { ScopedFlushDenormal::ScopedFlushDenormal() { -// For now, we flush denormals only on SSE 3. Other architectures such as ARM -// can be added as needed. + // For now, we flush denormals only on SSE 3. Other architectures such as ARM + // can be added as needed. #ifdef DENORM_USE_INTRINSICS if (TestCPUFeature(SSE3)) { diff --git a/tensorflow/core/platform/device_tracer_test.cc b/tensorflow/core/platform/device_tracer_test.cc index c0c08dabacbcb9fdbbfd9bdbe16bcfaea7328507..89f14e905afa4e2c10055f59721fe4cabf082781 100644 --- a/tensorflow/core/platform/device_tracer_test.cc +++ b/tensorflow/core/platform/device_tracer_test.cc @@ -77,7 +77,8 @@ class DeviceTracerTest : public ::testing::Test { Node* y_neg = test::graph::Unary(&graph, "Neg", i); y_neg_ = y_neg->name(); - y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); + y_neg->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:GPU:0"); test::graph::ToGraphDef(&graph, &def_); } diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index 557bfa87e50a85a6f9de86548931ea215d8ac7ff..34aaf3f78ba983de2ca84cd5281219a244cdbd72 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -286,7 +286,7 @@ class Env { // "version" should be the version of the library or NULL // returns the name that LoadLibrary() can use virtual string FormatLibraryFileName(const string& name, - const string& version) = 0; + const string& version) = 0; private: // Returns a possible list of local temporary directories. @@ -353,6 +353,7 @@ class EnvWrapper : public Env { const string& version) override { return target_->FormatLibraryFileName(name, version); } + private: Env* target_; }; diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index 14755891fa2d3b916396c75c9647acafe66ec524..b9866cf641ac9126a3a7a3e9ecb2d3bc8f49ebc3 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -131,18 +131,19 @@ Status FileSystem::GetMatchingPaths(const string& pattern, if (children.empty()) continue; // This IsDirectory call can be expensive for some FS. Parallelizing it. children_dir_status.resize(children.size()); - ForEach(0, children.size(), [this, ¤t_dir, &children, &fixed_prefix, - &children_dir_status](int i) { - const string child_path = io::JoinPath(current_dir, children[i]); - // In case the child_path doesn't start with the fixed_prefix then - // we don't need to explore this path. - if (!StringPiece(child_path).starts_with(fixed_prefix)) { - children_dir_status[i] = - Status(tensorflow::error::CANCELLED, "Operation not needed"); - } else { - children_dir_status[i] = IsDirectory(child_path); - } - }); + ForEach(0, children.size(), + [this, ¤t_dir, &children, &fixed_prefix, + &children_dir_status](int i) { + const string child_path = io::JoinPath(current_dir, children[i]); + // In case the child_path doesn't start with the fixed_prefix then + // we don't need to explore this path. + if (!StringPiece(child_path).starts_with(fixed_prefix)) { + children_dir_status[i] = Status(tensorflow::error::CANCELLED, + "Operation not needed"); + } else { + children_dir_status[i] = IsDirectory(child_path); + } + }); for (int i = 0; i < children.size(); ++i) { const string child_path = io::JoinPath(current_dir, children[i]); // If the IsDirectory call was cancelled we bail. diff --git a/tensorflow/core/platform/gif.h b/tensorflow/core/platform/gif.h index 9c72d34ff518abcabf773af607589fe8114beebf..ab095a35c93517c6527b55bd922dbeb46d695ca4 100644 --- a/tensorflow/core/platform/gif.h +++ b/tensorflow/core/platform/gif.h @@ -20,7 +20,8 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) #include "tensorflow/core/platform/google/build_config/gif.h" -#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) #include #else #error Define the appropriate PLATFORM_ macro for this platform diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 0baeac09841073ad6013a4700646e82d5d97182f..74863293a32451e8881c93de468539b913169aaa 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -164,8 +164,9 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) { } else { hdfs_->hdfsBuilderSetNameNode(builder, nn.c_str()); } - // KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME is the build in - // environment variable of Kerberos, so KERB_TICKET_CACHE_PATH and related code are unnecessary. + // KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME is + // the build in environment variable of Kerberos, so KERB_TICKET_CACHE_PATH + // and related code are unnecessary. char* ticket_cache_path = getenv("KERB_TICKET_CACHE_PATH"); if (ticket_cache_path != nullptr) { hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path); diff --git a/tensorflow/core/platform/jpeg.h b/tensorflow/core/platform/jpeg.h index edbcbd960a7d61970119bfb385f075e1d3ffb96f..1b5e633f0aad09850afa82bee59d45c7943bbd8a 100644 --- a/tensorflow/core/platform/jpeg.h +++ b/tensorflow/core/platform/jpeg.h @@ -20,7 +20,8 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) #include "tensorflow/core/platform/google/build_config/jpeg.h" -#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) #include #include #include diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h index dc389a8741501d27394ac559c95eaa73c2014afd..7bb9fc264fbf6ee3f20e9b2687c9ba52b6171ec4 100644 --- a/tensorflow/core/platform/mem.h +++ b/tensorflow/core/platform/mem.h @@ -59,6 +59,9 @@ void MallocExtension_ReleaseToSystem(std::size_t num_bytes); // routine, this routine returns 0. std::size_t MallocExtension_GetAllocatedSize(const void* p); +// Returns the amount of RAM available in kB, or INT64_MAX if unknown. +int64 AvailableRam(); + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/png.h b/tensorflow/core/platform/png.h index 5b0203c343e6b1764a9cc8a7908919422d826bcb..dad18d72195953e78c6a169a19b9182ae6571485 100644 --- a/tensorflow/core/platform/png.h +++ b/tensorflow/core/platform/png.h @@ -20,7 +20,8 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) #include "tensorflow/core/platform/google/build_config/png.h" -#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) #include #else #error Define the appropriate PLATFORM_ macro for this platform diff --git a/tensorflow/core/platform/posix/error.cc b/tensorflow/core/platform/posix/error.cc index cda6d7d8f9d6ad3e7f2c8fa56cc99a8dbe07fa00..2bb9443fb3c45e0cd4bb31a48539355747684b5f 100644 --- a/tensorflow/core/platform/posix/error.cc +++ b/tensorflow/core/platform/posix/error.cc @@ -73,19 +73,19 @@ error::Code ErrnoToCode(int err_number) { case ECHILD: // No child processes case EISCONN: // Socket is connected #if !defined(_WIN32) && !defined(__HAIKU__) - case ENOTBLK: // Block device required + case ENOTBLK: // Block device required #endif - case ENOTCONN: // The socket is not connected - case EPIPE: // Broken pipe + case ENOTCONN: // The socket is not connected + case EPIPE: // Broken pipe #if !defined(_WIN32) - case ESHUTDOWN: // Cannot send after transport endpoint shutdown + case ESHUTDOWN: // Cannot send after transport endpoint shutdown #endif - case ETXTBSY: // Text file busy + case ETXTBSY: // Text file busy code = error::FAILED_PRECONDITION; break; - case ENOSPC: // No space left on device + case ENOSPC: // No space left on device #if !defined(_WIN32) - case EDQUOT: // Disk quota exceeded + case EDQUOT: // Disk quota exceeded #endif case EMFILE: // Too many open files case EMLINK: // Too many links @@ -95,7 +95,7 @@ error::Code ErrnoToCode(int err_number) { case ENOMEM: // Not enough space case ENOSR: // No STREAM resources #if !defined(_WIN32) && !defined(__HAIKU__) - case EUSERS: // Too many users + case EUSERS: // Too many users #endif code = error::RESOURCE_EXHAUSTED; break; @@ -104,17 +104,17 @@ error::Code ErrnoToCode(int err_number) { case ERANGE: // Result too large code = error::OUT_OF_RANGE; break; - case ENOSYS: // Function not implemented - case ENOTSUP: // Operation not supported - case EAFNOSUPPORT: // Address family not supported + case ENOSYS: // Function not implemented + case ENOTSUP: // Operation not supported + case EAFNOSUPPORT: // Address family not supported #if !defined(_WIN32) - case EPFNOSUPPORT: // Protocol family not supported + case EPFNOSUPPORT: // Protocol family not supported #endif case EPROTONOSUPPORT: // Protocol not supported #if !defined(_WIN32) && !defined(__HAIKU__) case ESOCKTNOSUPPORT: // Socket type not supported #endif - case EXDEV: // Improper link + case EXDEV: // Improper link code = error::UNIMPLEMENTED; break; case EAGAIN: // Resource temporarily unavailable @@ -123,7 +123,7 @@ error::Code ErrnoToCode(int err_number) { case ECONNRESET: // Connection reset case EINTR: // Interrupted function call #if !defined(_WIN32) - case EHOSTDOWN: // Host is down + case EHOSTDOWN: // Host is down #endif case EHOSTUNREACH: // Host is unreachable case ENETDOWN: // Network is down @@ -139,7 +139,7 @@ error::Code ErrnoToCode(int err_number) { break; case EDEADLK: // Resource deadlock avoided #if !defined(_WIN32) - case ESTALE: // Stale file handle + case ESTALE: // Stale file handle #endif code = error::ABORTED; break; @@ -158,7 +158,7 @@ error::Code ErrnoToCode(int err_number) { case ENOMSG: // No message of the desired type case EPROTO: // Protocol error #if !defined(_WIN32) && !defined(__HAIKU__) - case EREMOTE: // Object is remote + case EREMOTE: // Object is remote #endif code = error::UNKNOWN; break; diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc index 614ee00b0133976e9fe49caf7c75a01194e10237..494acde803a778fb839a7444e4d5ac2fd094eb09 100644 --- a/tensorflow/core/platform/posix/port.cc +++ b/tensorflow/core/platform/posix/port.cc @@ -29,6 +29,7 @@ limitations under the License. #if defined(__linux__) && !defined(__ANDROID__) #include +#include #endif #include #include @@ -171,5 +172,16 @@ double NominalCPUFrequency() { #endif } +int64 AvailableRam() { +#if defined(__linux__) && !defined(__ANDROID__) + struct sysinfo info; + int err = sysinfo(&info); + if (err == 0) { + return info.freeram / 1024; + } +#endif + return INT64_MAX; +} + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h index 8604b01c53ef69040a919dadda73df897e98b0e1..ce2069b004473a684a1882068d3479ed049c58d6 100644 --- a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h +++ b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h @@ -58,8 +58,8 @@ class AndroidArmV7ACpuUtilsHelper : public ICpuUtilsHelper { TF_DISALLOW_COPY_AND_ASSIGN(AndroidArmV7ACpuUtilsHelper); }; -} // profile_utils -} // tensorflow +} // namespace profile_utils +} // namespace tensorflow #endif // defined(__ANDROID__) && (__ANDROID_API__ >= 21) && // (defined(__ARM_ARCH_7A__) || defined(__aarch64__)) diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.cc b/tensorflow/core/platform/profile_utils/cpu_utils.cc index d3362690d7e08c8e88e8168b62c8134b6af5a319..02de7d1362bbfca645d07ee72165283351944b9b 100644 --- a/tensorflow/core/platform/profile_utils/cpu_utils.cc +++ b/tensorflow/core/platform/profile_utils/cpu_utils.cc @@ -28,15 +28,17 @@ namespace profile_utils { static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr; -#if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) - /* static */ uint64 CpuUtils::GetCycleCounterFrequency() { - static const uint64 cpu_frequency = GetCycleCounterFrequencyImpl(); - return cpu_frequency; +#if (defined(__powerpc__) || \ + defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \ + (defined(__s390x__)) +/* static */ uint64 CpuUtils::GetCycleCounterFrequency() { + static const uint64 cpu_frequency = GetCycleCounterFrequencyImpl(); + return cpu_frequency; } #else - /* static */ int64 CpuUtils::GetCycleCounterFrequency() { - static const int64 cpu_frequency = GetCycleCounterFrequencyImpl(); - return cpu_frequency; +/* static */ int64 CpuUtils::GetCycleCounterFrequency() { + static const int64 cpu_frequency = GetCycleCounterFrequencyImpl(); + return cpu_frequency; } #endif diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.h b/tensorflow/core/platform/profile_utils/cpu_utils.h index e95843b80a53f2d711c2ee162203510e1a7821ae..7b580c8bf606cdd9acf998fa21cb1d946e5e6ada 100644 --- a/tensorflow/core/platform/profile_utils/cpu_utils.h +++ b/tensorflow/core/platform/profile_utils/cpu_utils.h @@ -94,14 +94,16 @@ class CpuUtils { #endif } - // Return cycle counter frequency. - // As this method caches the cpu frequency internally, - // the first call will incur overhead, but not subsequent calls. - #if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) - static uint64 GetCycleCounterFrequency(); - #else - static int64 GetCycleCounterFrequency(); - #endif +// Return cycle counter frequency. +// As this method caches the cpu frequency internally, +// the first call will incur overhead, but not subsequent calls. +#if (defined(__powerpc__) || \ + defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \ + (defined(__s390x__)) + static uint64 GetCycleCounterFrequency(); +#else + static int64 GetCycleCounterFrequency(); +#endif // Return micro second per each clock // As this method caches the cpu frequency internally, diff --git a/tensorflow/core/platform/profile_utils/cpu_utils_test.cc b/tensorflow/core/platform/profile_utils/cpu_utils_test.cc index 5b11b684dd9833bf742faaeaa3e79d2b49a78c6d..eb8161fbfd5ddfc796edd66a9119ad70c3c1de8e 100644 --- a/tensorflow/core/platform/profile_utils/cpu_utils_test.cc +++ b/tensorflow/core/platform/profile_utils/cpu_utils_test.cc @@ -53,15 +53,17 @@ TEST_F(CpuUtilsTest, CheckGetCurrentClockCycle) { } TEST_F(CpuUtilsTest, CheckCycleCounterFrequency) { - #if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) - const uint64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); - CHECK_GT(cpu_frequency, 0); - CHECK_NE(cpu_frequency, unsigned(CpuUtils::INVALID_FREQUENCY)); - #else - const int64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); - CHECK_GT(cpu_frequency, 0); - CHECK_NE(cpu_frequency, CpuUtils::INVALID_FREQUENCY); - #endif +#if (defined(__powerpc__) || \ + defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \ + (defined(__s390x__)) + const uint64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); + CHECK_GT(cpu_frequency, 0); + CHECK_NE(cpu_frequency, unsigned(CpuUtils::INVALID_FREQUENCY)); +#else + const int64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); + CHECK_GT(cpu_frequency, 0); + CHECK_NE(cpu_frequency, CpuUtils::INVALID_FREQUENCY); +#endif if (DBG) { LOG(INFO) << "Cpu frequency = " << cpu_frequency; } diff --git a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h index 51c54d50d1dadcf78e8263ce44b07c998b68c05c..11b739c0096b5b5fd498bb5c753a54c8b1628208 100644 --- a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h +++ b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h @@ -47,7 +47,7 @@ class ICpuUtilsHelper { TF_DISALLOW_COPY_AND_ASSIGN(ICpuUtilsHelper); }; -} // profile_utils -} // tensorflow +} // namespace profile_utils +} // namespace tensorflow #endif // TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__ diff --git a/tensorflow/core/platform/protobuf_internal.h b/tensorflow/core/platform/protobuf_internal.h index 7d6e8f57a62e08a7897bdccdeb7033363b282bd4..2f151a5aee6af067e4536bb569b4c0799c831b98 100644 --- a/tensorflow/core/platform/protobuf_internal.h +++ b/tensorflow/core/platform/protobuf_internal.h @@ -45,8 +45,8 @@ Status ParseAny(const google::protobuf::Any& any, T* message, #ifdef TENSORFLOW_LITE_PROTOS if (any.type_url() != strings::StrCat("type.googleapis.com/", type_name)) { return errors::FailedPrecondition( - "Expected Any type_url for: ", type_name, ". Got: ", - string(any.type_url().data(), any.type_url().size()), "."); + "Expected Any type_url for: ", type_name, + ". Got: ", string(any.type_url().data(), any.type_url().size()), "."); } if (!message->ParseFromString(any.value())) { return errors::FailedPrecondition("Failed to unpack: ", diff --git a/tensorflow/core/platform/s3/aws_logging.cc b/tensorflow/core/platform/s3/aws_logging.cc index fbca0acc36b01fa91dece4bdd0d19b7059dc114e..44317f1a3e41831b903bd0044d53d1eba80168df 100644 --- a/tensorflow/core/platform/s3/aws_logging.cc +++ b/tensorflow/core/platform/s3/aws_logging.cc @@ -96,7 +96,7 @@ Aws::Utils::Logging::LogLevel ParseLogLevelFromEnv() { return log_level; } -} +} // namespace static bool initialized = false; static mutex s3_logging_mutex(LINKER_INITIALIZED); diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h index d0d6bb59499797d6b74ea61bb5dab7406b5e1bbb..8177e48dba52f11458faeb3092a12e6801f6b7ef 100644 --- a/tensorflow/core/platform/s3/s3_file_system.h +++ b/tensorflow/core/platform/s3/s3_file_system.h @@ -63,7 +63,7 @@ class S3FileSystem : public FileSystem { // variables. // By default S3 access regional endpoint, with region // controlled by `AWS_REGION`. The endpoint could be overridden - // with explicity `S3_ENDPOINT`. S3 use HTTPS by default. + // explicitly with `S3_ENDPOINT`. S3 uses HTTPS by default. // If S3_USE_HTTPS=0 is specified, HTTP is used. Also, // S3_VERIFY_SSL=0 could disable SSL verification in case // HTTPS is used. diff --git a/tensorflow/core/platform/setround.cc b/tensorflow/core/platform/setround.cc index 0c66da09bb9aa1c892063be11c66aedaf75d7eb6..592626bfa17e691d1b10ddce5c7f0f31ed825861 100644 --- a/tensorflow/core/platform/setround.cc +++ b/tensorflow/core/platform/setround.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/core/platform/setround.h" - namespace tensorflow { namespace port { diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h index a6636225ccbbc8154e290cd7f1aa6cafe3d2027a..327237dba933230cb313dd06091d2ff2ca3cc4b2 100644 --- a/tensorflow/core/platform/test_benchmark.h +++ b/tensorflow/core/platform/test_benchmark.h @@ -60,7 +60,7 @@ class Benchmark { private: string name_; int num_args_; - std::vector> args_; + std::vector > args_; void (*fn0_)(int) = nullptr; void (*fn1_)(int, int) = nullptr; void (*fn2_)(int, int, int) = nullptr; diff --git a/tensorflow/core/platform/windows/cpu_info.h b/tensorflow/core/platform/windows/cpu_info.h index d6e78dbc8f9f25070d94141e46d35dcb8d727ef7..f20939d3c0ff02be30f19be170644fab44b6f45e 100644 --- a/tensorflow/core/platform/windows/cpu_info.h +++ b/tensorflow/core/platform/windows/cpu_info.h @@ -22,8 +22,10 @@ limitations under the License. // Byte order defines provided by gcc. MSVC doesn't define those so // we define them here. // We assume that all windows platform out there are little endian. +#if defined(_MSC_VER) && !defined(__clang__) #define __ORDER_LITTLE_ENDIAN__ 0x4d2 #define __ORDER_BIG_ENDIAN__ 0x10e1 #define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__ +#endif #endif // TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_ diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc index 788a4bf4b1af74393099d1b590a1e589d9a07f25..41b264417071cadb5f70806b458ee2b46ebb2feb 100644 --- a/tensorflow/core/platform/windows/env.cc +++ b/tensorflow/core/platform/windows/env.cc @@ -24,9 +24,9 @@ limitations under the License. #undef LoadLibrary #undef ERROR +#include #include #include -#include #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/platform/load_library.h" @@ -53,8 +53,7 @@ class StdThread : public Thread { class WindowsEnv : public Env { public: - WindowsEnv() - : GetSystemTimePreciseAsFileTime_(NULL) { + WindowsEnv() : GetSystemTimePreciseAsFileTime_(NULL) { // GetSystemTimePreciseAsFileTime function is only available in the latest // versions of Windows. For that reason, we try to look it up in // kernel32.dll at runtime and use an alternative option if the function @@ -72,8 +71,8 @@ class WindowsEnv : public Env { } bool MatchPath(const string& path, const string& pattern) override { - std::wstring ws_path(WindowsFileSystem::Utf8ToWideChar(path)); - std::wstring ws_pattern(WindowsFileSystem::Utf8ToWideChar(pattern)); + std::wstring ws_path(WindowsFileSystem::Utf8ToWideChar(path)); + std::wstring ws_pattern(WindowsFileSystem::Utf8ToWideChar(pattern)); return PathMatchSpecW(ws_path.c_str(), ws_pattern.c_str()) == TRUE; } @@ -122,14 +121,14 @@ class WindowsEnv : public Env { SetThreadpoolTimer(timer, &FileDueTime, 0, 0); } - Status LoadLibrary(const char *library_filename, void** handle) override { + Status LoadLibrary(const char* library_filename, void** handle) override { std::string file_name = library_filename; std::replace(file_name.begin(), file_name.end(), '/', '\\'); std::wstring ws_file_name(WindowsFileSystem::Utf8ToWideChar(file_name)); HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL, - LOAD_WITH_ALTERED_SEARCH_PATH); + LOAD_WITH_ALTERED_SEARCH_PATH); if (!hModule) { return errors::NotFound(file_name + " not found"); } @@ -138,31 +137,30 @@ class WindowsEnv : public Env { } Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) override { + void** symbol) override { FARPROC found_symbol; found_symbol = GetProcAddress((HMODULE)handle, symbol_name); if (found_symbol == NULL) { return errors::NotFound(std::string(symbol_name) + " not found"); } - *symbol = (void **)found_symbol; + *symbol = (void**)found_symbol; return Status::OK(); } - string FormatLibraryFileName(const string& name, const string& version) - override { + string FormatLibraryFileName(const string& name, + const string& version) override { string filename; if (version.size() == 0) { filename = name + ".dll"; - } - else { + } else { filename = name + version + ".dll"; } return filename; } private: - typedef VOID(WINAPI * FnGetSystemTimePreciseAsFileTime)(LPFILETIME); + typedef VOID(WINAPI* FnGetSystemTimePreciseAsFileTime)(LPFILETIME); FnGetSystemTimePreciseAsFileTime GetSystemTimePreciseAsFileTime_; }; diff --git a/tensorflow/core/platform/windows/error.cc b/tensorflow/core/platform/windows/error.cc index 39e941a3834f7f7cd03e7791d43d56f190dc1fd6..291fc5003fb6bbc07274cdea72d73e92a453f363 100644 --- a/tensorflow/core/platform/windows/error.cc +++ b/tensorflow/core/platform/windows/error.cc @@ -21,7 +21,7 @@ namespace internal { std::string GetWindowsErrorMessage(DWORD err) { LPSTR buffer = NULL; DWORD flags = FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS; + FORMAT_MESSAGE_IGNORE_INSERTS; FormatMessageA(flags, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), reinterpret_cast(&buffer), 0, NULL); std::string message = buffer; diff --git a/tensorflow/core/platform/windows/error.h b/tensorflow/core/platform/windows/error.h index 026e0d5aa946f7c851dacc05a3306631e06886aa..ba643a0fa8f92f58fbd88ac00fba3f663bb7e0f2 100644 --- a/tensorflow/core/platform/windows/error.h +++ b/tensorflow/core/platform/windows/error.h @@ -24,9 +24,7 @@ namespace tensorflow { namespace internal { std::string GetWindowsErrorMessage(DWORD err); - -} } +} // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_ERROR_H_ - diff --git a/tensorflow/core/platform/windows/integral_types.h b/tensorflow/core/platform/windows/integral_types.h index 4970b8ca6a1673dd24d2d445348fe5b337ae13be..46338a536dbc3541763e62954fee74b2a5a0700b 100644 --- a/tensorflow/core/platform/windows/integral_types.h +++ b/tensorflow/core/platform/windows/integral_types.h @@ -1,18 +1,18 @@ - /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_ #define TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_ diff --git a/tensorflow/core/platform/windows/net.cc b/tensorflow/core/platform/windows/net.cc index 46eb072d42592028859122a4cad3d9478a96476e..2ab558ab95cafd15b10f7b887c846b32ab7e4c47 100644 --- a/tensorflow/core/platform/windows/net.cc +++ b/tensorflow/core/platform/windows/net.cc @@ -26,7 +26,7 @@ limitations under the License. #undef ERROR -#pragma comment(lib,"Ws2_32.lib") +#pragma comment(lib, "Ws2_32.lib") namespace tensorflow { namespace internal { @@ -44,8 +44,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { CHECK_GE(*port, 0); CHECK_LE(*port, 65535); if (sock == INVALID_SOCKET) { - LOG(ERROR) << "socket() failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(ERROR) << "socket() failed: " + << GetWindowsErrorMessage(WSAGetLastError()); return false; } @@ -54,8 +54,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { int result = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&one), sizeof(one)); if (result == SOCKET_ERROR) { - LOG(ERROR) << "setsockopt() failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(ERROR) << "setsockopt() failed: " + << GetWindowsErrorMessage(WSAGetLastError()); closesocket(sock); return false; } @@ -66,8 +66,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { addr.sin_port = htons((uint16_t)*port); result = bind(sock, (struct sockaddr*)&addr, sizeof(addr)); if (result == SOCKET_ERROR) { - LOG(WARNING) << "bind(port=" << *port << ") failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(WARNING) << "bind(port=" << *port + << ") failed: " << GetWindowsErrorMessage(WSAGetLastError()); closesocket(sock); return false; } @@ -75,8 +75,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { // Get the bound port number. result = getsockname(sock, (struct sockaddr*)&addr, &addr_len); if (result == SOCKET_ERROR) { - LOG(WARNING) << "getsockname() failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(WARNING) << "getsockname() failed: " + << GetWindowsErrorMessage(WSAGetLastError()); closesocket(sock); return false; } diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index e327d53949caf7e2d30e6deba0be2848f010afc2..582b232054b850a2ef5ab8f47c089eb35a7bb3cf 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -149,8 +149,20 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { - // TODO(yuefengz): implement it for this platform. +#ifdef TENSORFLOW_USE_ABSL + return absl::base_internal::NominalCPUFrequency(); +#else return 1.0; +#endif +} + +int64 AvailableRam() { + MEMORYSTATUSEX statex; + statex.dwLength = sizeof(statex); + if (GlobalMemoryStatusEx(&statex)) { + return statex.ullAvailPhys / 1024; + } + return INT64_MAX; } } // namespace port diff --git a/tensorflow/core/platform/windows/subprocess.h b/tensorflow/core/platform/windows/subprocess.h index b65313363ed79ab327414179a9923ba2d436dd0b..66ec44885d52195b807f4957aec6d590324b2975 100644 --- a/tensorflow/core/platform/windows/subprocess.h +++ b/tensorflow/core/platform/windows/subprocess.h @@ -19,8 +19,7 @@ limitations under the License. namespace tensorflow { // SubProcess is not yet implemented for Windows. -class SubProcess { -}; +class SubProcess {}; } // namespace tensorflow diff --git a/tensorflow/core/platform/windows/test.cc b/tensorflow/core/platform/windows/test.cc index 0ffd02ff14849d77761e85c30388dc49a53c84db..584acad91b24fc6be9b93f71b7d44b0fba3cb2e8 100644 --- a/tensorflow/core/platform/windows/test.cc +++ b/tensorflow/core/platform/windows/test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/net.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc index 604348fe03a01d44195ba8a8ff427ae3ef3a4137..b6b3722caae4dc0cdc0ddff91be479ab91a744b2 100644 --- a/tensorflow/core/platform/windows/windows_file_system.cc +++ b/tensorflow/core/platform/windows/windows_file_system.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include #include -#include #undef StrCat #include #include @@ -75,16 +75,16 @@ SSIZE_T pread(HANDLE hfile, char* src, size_t num_bytes, uint64_t offset) { if (TRUE == read_result) { result = bytes_read; } else if ((FALSE == read_result) && - ((last_error = GetLastError()) != ERROR_IO_PENDING)) { + ((last_error = GetLastError()) != ERROR_IO_PENDING)) { result = (last_error == ERROR_HANDLE_EOF) ? 0 : -1; } else { - if (ERROR_IO_PENDING == last_error) { // Otherwise bytes_read already has the result. - BOOL overlapped_result = ::GetOverlappedResult(hfile, &overlapped, - &bytes_read, TRUE); + if (ERROR_IO_PENDING == + last_error) { // Otherwise bytes_read already has the result. + BOOL overlapped_result = + ::GetOverlappedResult(hfile, &overlapped, &bytes_read, TRUE); if (FALSE == overlapped_result) { result = (::GetLastError() == ERROR_HANDLE_EOF) ? 0 : -1; - } - else { + } else { result = bytes_read; } } @@ -151,11 +151,11 @@ class WindowsWritableFile : public WritableFile { Status Append(const StringPiece& data) override { DWORD bytes_written = 0; DWORD data_size = static_cast(data.size()); - BOOL write_result = ::WriteFile(hfile_, data.data(), data_size, - &bytes_written, NULL); + BOOL write_result = + ::WriteFile(hfile_, data.data(), data_size, &bytes_written, NULL); if (FALSE == write_result) { - return IOErrorFromWindowsError( - "Failed to WriteFile: " + filename_, ::GetLastError()); + return IOErrorFromWindowsError("Failed to WriteFile: " + filename_, + ::GetLastError()); } assert(size_t(bytes_written) == data.size()); @@ -171,8 +171,8 @@ class WindowsWritableFile : public WritableFile { } if (FALSE == ::CloseHandle(hfile_)) { - return IOErrorFromWindowsError( - "CloseHandle failed for: " + filename_, ::GetLastError()); + return IOErrorFromWindowsError("CloseHandle failed for: " + filename_, + ::GetLastError()); } hfile_ = INVALID_HANDLE_VALUE; @@ -187,9 +187,7 @@ class WindowsWritableFile : public WritableFile { return Status::OK(); } - Status Sync() override { - return Flush(); - } + Status Sync() override { return Flush(); } }; class WinReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { @@ -204,7 +202,10 @@ class WinReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { public: WinReadOnlyMemoryRegion(const std::string& filename, HANDLE hfile, HANDLE hmap, const void* address, uint64 length) - : filename_(filename), hfile_(hfile), hmap_(hmap), address_(address), + : filename_(filename), + hfile_(hfile), + hmap_(hmap), + address_(address), length_(length) {} ~WinReadOnlyMemoryRegion() { @@ -238,9 +239,9 @@ Status WindowsFileSystem::NewRandomAccessFile( // almost all tests would work with a possible exception of fault_injection. DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, - share_mode, NULL, OPEN_EXISTING, file_flags, - NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, share_mode, NULL, + OPEN_EXISTING, file_flags, NULL); if (INVALID_HANDLE_VALUE == hfile) { string context = "NewRandomAccessFile failed to Create/Open: " + fname; @@ -258,9 +259,9 @@ Status WindowsFileSystem::NewWritableFile( result->reset(); DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, - share_mode, NULL, CREATE_ALWAYS, - FILE_ATTRIBUTE_NORMAL, NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, share_mode, + NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL); if (INVALID_HANDLE_VALUE == hfile) { string context = "Failed to create a NewWriteableFile: " + fname; @@ -278,9 +279,9 @@ Status WindowsFileSystem::NewAppendableFile( result->reset(); DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, - share_mode, NULL, OPEN_ALWAYS, - FILE_ATTRIBUTE_NORMAL, NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, share_mode, + NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL); if (INVALID_HANDLE_VALUE == hfile) { string context = "Failed to create a NewAppendableFile: " + fname; @@ -316,9 +317,9 @@ Status WindowsFileSystem::NewReadOnlyMemoryRegionFromFile( file_flags |= FILE_FLAG_OVERLAPPED; DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, - share_mode, NULL, OPEN_EXISTING, file_flags, - NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, share_mode, NULL, + OPEN_EXISTING, file_flags, NULL); if (INVALID_HANDLE_VALUE == hfile) { return IOErrorFromWindowsError( @@ -345,28 +346,32 @@ Status WindowsFileSystem::NewReadOnlyMemoryRegionFromFile( NULL); // Mapping name if (!hmap) { - string context = "Failed to create file mapping for " - "NewReadOnlyMemoryRegionFromFile: " + fname; + string context = + "Failed to create file mapping for " + "NewReadOnlyMemoryRegionFromFile: " + + fname; return IOErrorFromWindowsError(context, ::GetLastError()); } UniqueCloseHandlePtr map_guard(hmap, CloseHandleFunc); - const void* mapped_region = ::MapViewOfFileEx( - hmap, FILE_MAP_READ, - 0, // High DWORD of access start - 0, // Low DWORD - file_size, - NULL); // Let the OS choose the mapping + const void* mapped_region = + ::MapViewOfFileEx(hmap, FILE_MAP_READ, + 0, // High DWORD of access start + 0, // Low DWORD + file_size, + NULL); // Let the OS choose the mapping if (!mapped_region) { - string context = "Failed to MapViewOfFile for " - "NewReadOnlyMemoryRegionFromFile: " + fname; + string context = + "Failed to MapViewOfFile for " + "NewReadOnlyMemoryRegionFromFile: " + + fname; return IOErrorFromWindowsError(context, ::GetLastError()); } - result->reset(new WinReadOnlyMemoryRegion(fname, hfile, hmap, - mapped_region, file_size)); + result->reset(new WinReadOnlyMemoryRegion(fname, hfile, hmap, mapped_region, + file_size)); map_guard.release(); file_guard.release(); @@ -404,8 +409,8 @@ Status WindowsFileSystem::GetChildren(const string& dir, } do { - string file_name = WideCharToUtf8(find_data.cFileName); - const StringPiece basename = file_name; + string file_name = WideCharToUtf8(find_data.cFileName); + const StringPiece basename = file_name; if (basename != "." && basename != "..") { result->push_back(file_name); } @@ -457,8 +462,7 @@ Status WindowsFileSystem::GetFileSize(const string& fname, uint64* size) { file_size.HighPart = attrs.nFileSizeHigh; file_size.LowPart = attrs.nFileSizeLow; *size = file_size.QuadPart; - } - else { + } else { string context = "Can not get size for: " + fname; result = IOErrorFromWindowsError(context, ::GetLastError()); } @@ -472,7 +476,7 @@ Status WindowsFileSystem::RenameFile(const string& src, const string& target) { std::wstring ws_translated_src = Utf8ToWideChar(TranslateName(src)); std::wstring ws_translated_target = Utf8ToWideChar(TranslateName(target)); if (!::MoveFileExW(ws_translated_src.c_str(), ws_translated_target.c_str(), - MOVEFILE_REPLACE_EXISTING)) { + MOVEFILE_REPLACE_EXISTING)) { string context(strings::StrCat("Failed to rename: ", src, " to: ", target)); result = IOErrorFromWindowsError(context, ::GetLastError()); } diff --git a/tensorflow/core/platform/windows/windows_file_system.h b/tensorflow/core/platform/windows/windows_file_system.h index 8dcc1530370f0615ec45785a1f3d10ce828d11a3..ba0302f0fd8b56dabaf9271a725bebdac4716102 100644 --- a/tensorflow/core/platform/windows/windows_file_system.h +++ b/tensorflow/core/platform/windows/windows_file_system.h @@ -63,33 +63,35 @@ class WindowsFileSystem : public FileSystem { Status RenameFile(const string& src, const string& target) override; - string TranslateName(const string& name) const override { - return name; - } + string TranslateName(const string& name) const override { return name; } static std::wstring Utf8ToWideChar(const string& utf8str) { - int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), NULL, 0); - std::wstring ws_translated_str(size_required, 0); - MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), &ws_translated_str[0], size_required); - return ws_translated_str; + int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), + (int)utf8str.size(), NULL, 0); + std::wstring ws_translated_str(size_required, 0); + MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), + &ws_translated_str[0], size_required); + return ws_translated_str; } - static string WideCharToUtf8(const std::wstring &wstr) { - if (wstr.empty()) return std::string(); - int size_required = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL); - string utf8_translated_str(size_required, 0); - WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), &utf8_translated_str[0], size_required, NULL, NULL); - return utf8_translated_str; + static string WideCharToUtf8(const std::wstring& wstr) { + if (wstr.empty()) return std::string(); + int size_required = WideCharToMultiByte( + CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL); + string utf8_translated_str(size_required, 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), + &utf8_translated_str[0], size_required, NULL, NULL); + return utf8_translated_str; } }; class LocalWinFileSystem : public WindowsFileSystem { -public: - string TranslateName(const string& name) const override { - StringPiece scheme, host, path; - io::ParseURI(name, &scheme, &host, &path); - return path.ToString(); - } + public: + string TranslateName(const string& name) const override { + StringPiece scheme, host, path; + io::ParseURI(name, &scheme, &host, &path); + return path.ToString(); + } }; } // namespace tensorflow diff --git a/tensorflow/core/profiler/README.md b/tensorflow/core/profiler/README.md index 7997bdfa050a8c7a10b1edabfc35243a05d47ba8..57d76eb4cb9382790c80a0d55ee94b64e7b9dcdc 100644 --- a/tensorflow/core/profiler/README.md +++ b/tensorflow/core/profiler/README.md @@ -257,7 +257,7 @@ bug fix. `OpLogProto` is a good plus if it is used. #### Teams -* Xin Pan (xpan@google.com, github: panyx0718) +* Xin Pan * Chris Antaki * Yao Zhang * Jon Shlens diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc index d05143aff9b8cc0b9a0e9af9445ba79345e4bf62..e968b9c97e28eeae22954102d5f0e07e09d75f7f 100644 --- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc +++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc @@ -53,10 +53,13 @@ class TFProfAdvisorTest : public ::testing::Test { NodeExecStats node_stat; node_stat.set_all_start_micros(start_miros); node_stat.set_op_end_rel_micros(end_rel_micros); - node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0", node_stat); - node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0:stream:all", + node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0", node_stat); - node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0:stream:0", + node->AddStepStat(step, + "/job:localhost/replica:0/task:0/device:GPU:0:stream:all", + node_stat); + node->AddStepStat(step, + "/job:localhost/replica:0/task:0/device:GPU:0:stream:0", node_stat); return node; } diff --git a/tensorflow/core/profiler/internal/tfprof_op.cc b/tensorflow/core/profiler/internal/tfprof_op.cc index 5a8429d4893effc8bbfa0bf69e18b4a182e9a5df..3dce1d85db35436d162e73bf0946b320b899d5eb 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.cc +++ b/tensorflow/core/profiler/internal/tfprof_op.cc @@ -113,8 +113,9 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, root_->formatted_str = FormatNode(root_.get(), root_.get(), opts); } if (timeline) { - fprintf(stderr, "op view doesn't support timeline yet. " - "Consider graph/scope/code view.\n"); + fprintf(stderr, + "op view doesn't support timeline yet. " + "Consider graph/scope/code view.\n"); return root_.get(); } if (cnodes_map_.empty()) { @@ -265,9 +266,9 @@ string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const { double pct = 0.0; if (node->proto().total_parameters() > 0) { accu_pct = 100.0 * node->proto().total_parameters() / - root->proto().total_parameters(); - pct = 100.0 * node->proto().parameters() / - root->proto().total_parameters(); + root->proto().total_parameters(); + pct = + 100.0 * node->proto().parameters() / root->proto().total_parameters(); } attrs.push_back(strings::Printf( "%30s", @@ -282,9 +283,8 @@ string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const { double pct = 0.0; if (node->proto().total_float_ops() > 0) { accu_pct = 100.0 * node->proto().total_float_ops() / - root->proto().total_float_ops(); - pct = 100.0 * node->proto().float_ops() / - root->proto().total_float_ops(); + root->proto().total_float_ops(); + pct = 100.0 * node->proto().float_ops() / root->proto().total_float_ops(); } attrs.push_back(strings::Printf( diff --git a/tensorflow/core/profiler/internal/tfprof_op.h b/tensorflow/core/profiler/internal/tfprof_op.h index fe1c3b2ae826783c1405b6151b82f153c05d2901..aa22182d36cac8d7e1f9fb3143beadfdfe0efce6 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.h +++ b/tensorflow/core/profiler/internal/tfprof_op.h @@ -41,8 +41,7 @@ namespace tfprof { // to input ops. class TFOp : public TFMultiShow { public: - explicit TFOp() - : TFMultiShow() {} + explicit TFOp() : TFMultiShow() {} ~TFOp() override {} void AddNode(TFGraphNode* node) override; @@ -51,7 +50,7 @@ class TFOp : public TFMultiShow { private: const ShowMultiNode* ShowInternal(const Options& opts, - Timeline* timeline) override; + Timeline* timeline) override; int64 SearchRoot(const std::vector nodes, const std::vector& regexes); diff --git a/tensorflow/core/profiler/internal/tfprof_show.h b/tensorflow/core/profiler/internal/tfprof_show.h index 4d6de060705435c5346f6f49810b7dfc05d4530e..81b021549a49625cd5ba4a6ba8130f12cc7cf5f7 100644 --- a/tensorflow/core/profiler/internal/tfprof_show.h +++ b/tensorflow/core/profiler/internal/tfprof_show.h @@ -78,40 +78,43 @@ class TFShow { return nodes; } std::vector sorted_nodes = nodes; - std::sort(sorted_nodes.begin(), sorted_nodes.end(), [&opts](const T* n1, - const T* n2) { - if (n1->name() == kTFProfRoot) return true; - if (n2->name() == kTFProfRoot) return false; - bool name_cmp = n1->name() < n2->name(); - if (opts.order_by == kOrderBy[0]) { - return name_cmp; - } else if (opts.order_by == kOrderBy[1]) { - return n1->proto().total_requested_bytes() > - n2->proto().total_requested_bytes(); - } else if (opts.order_by == kOrderBy[2]) { - return n1->proto().total_peak_bytes() > n2->proto().total_peak_bytes(); - } else if (opts.order_by == kOrderBy[3]) { - return n1->proto().total_residual_bytes() > - n2->proto().total_residual_bytes(); - } else if (opts.order_by == kOrderBy[4]) { - return n1->proto().total_output_bytes() > - n2->proto().total_output_bytes(); - } else if (opts.order_by == kOrderBy[5]) { - return n1->proto().total_exec_micros() > - n2->proto().total_exec_micros(); - } else if (opts.order_by == kOrderBy[6]) { - return n1->proto().total_accelerator_exec_micros() > - n2->proto().total_accelerator_exec_micros(); - } else if (opts.order_by == kOrderBy[7]) { - return n1->proto().total_cpu_exec_micros() > - n2->proto().total_cpu_exec_micros(); - } else if (opts.order_by == kOrderBy[8]) { - return n1->proto().total_parameters() > n2->proto().total_parameters(); - } else if (opts.order_by == kOrderBy[9]) { - return n1->proto().total_float_ops() > n2->proto().total_float_ops(); - } - return name_cmp; - }); + std::sort(sorted_nodes.begin(), sorted_nodes.end(), + [&opts](const T* n1, const T* n2) { + if (n1->name() == kTFProfRoot) return true; + if (n2->name() == kTFProfRoot) return false; + bool name_cmp = n1->name() < n2->name(); + if (opts.order_by == kOrderBy[0]) { + return name_cmp; + } else if (opts.order_by == kOrderBy[1]) { + return n1->proto().total_requested_bytes() > + n2->proto().total_requested_bytes(); + } else if (opts.order_by == kOrderBy[2]) { + return n1->proto().total_peak_bytes() > + n2->proto().total_peak_bytes(); + } else if (opts.order_by == kOrderBy[3]) { + return n1->proto().total_residual_bytes() > + n2->proto().total_residual_bytes(); + } else if (opts.order_by == kOrderBy[4]) { + return n1->proto().total_output_bytes() > + n2->proto().total_output_bytes(); + } else if (opts.order_by == kOrderBy[5]) { + return n1->proto().total_exec_micros() > + n2->proto().total_exec_micros(); + } else if (opts.order_by == kOrderBy[6]) { + return n1->proto().total_accelerator_exec_micros() > + n2->proto().total_accelerator_exec_micros(); + } else if (opts.order_by == kOrderBy[7]) { + return n1->proto().total_cpu_exec_micros() > + n2->proto().total_cpu_exec_micros(); + } else if (opts.order_by == kOrderBy[8]) { + return n1->proto().total_parameters() > + n2->proto().total_parameters(); + } else if (opts.order_by == kOrderBy[9]) { + return n1->proto().total_float_ops() > + n2->proto().total_float_ops(); + } + return name_cmp; + }); return sorted_nodes; } diff --git a/tensorflow/core/profiler/internal/tfprof_show_multi.h b/tensorflow/core/profiler/internal/tfprof_show_multi.h index 2a2208d8e78efd5bc20d0db23e5fdaabbb3e8d5a..711d35f9753cf85f7f318a9ac3de40d6d2bf786e 100644 --- a/tensorflow/core/profiler/internal/tfprof_show_multi.h +++ b/tensorflow/core/profiler/internal/tfprof_show_multi.h @@ -50,7 +50,7 @@ class TFMultiShow { protected: virtual const ShowMultiNode* ShowInternal(const Options& opts, - Timeline* timeline) = 0; + Timeline* timeline) = 0; bool LookUpCheckPoint(const string& name, std::unique_ptr* tensor); diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.h b/tensorflow/core/profiler/internal/tfprof_timeline.h index 651ad3f0c1c232d6f4b3730133b78c1e7f96a9bc..baf3fb2bedb13e13b21940485ec439c19a97dd02 100644 --- a/tensorflow/core/profiler/internal/tfprof_timeline.h +++ b/tensorflow/core/profiler/internal/tfprof_timeline.h @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/profiler/internal/tfprof_node_show.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { namespace tfprof { diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 67da7bf4526235ae51eb172f8da9fc267cc12b98..50bfa9126789033c617e22f25dbb76273fccfc60 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -19,12 +19,12 @@ limitations under the License. // TensorFlow uses semantic versioning, see http://semver.org/. #define TF_MAJOR_VERSION 1 -#define TF_MINOR_VERSION 5 +#define TF_MINOR_VERSION 6 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", // "-beta", "-rc", "-rc.1") -#define TF_VERSION_SUFFIX "-rc1" +#define TF_VERSION_SUFFIX "-rc0" #define TF_STR_HELPER(x) #x #define TF_STR(x) TF_STR_HELPER(x) diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc index 1eab7e3d024c181f260500686b9127dd76dbe206..3a5f1f83af8d2d2324f3139568aa69f204cf1248 100644 --- a/tensorflow/core/util/bcast.cc +++ b/tensorflow/core/util/bcast.cc @@ -69,9 +69,9 @@ BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) { State curr = UNKNOWN; const int64 x_i = x[i]; // i-th dimension of x. const int64 y_i = y[i]; // i-th dimension of y. - int64 o_i; // i-th dimension of the output. - int64 bx_i; // i-th broadcast for x. - int64 by_i; // i-th broadcast for y. + int64 o_i; // i-th dimension of the output. + int64 bx_i; // i-th broadcast for x. + int64 by_i; // i-th broadcast for y. // Invariant: // o_i = x_i * bx_i = y_i * by_i if (x_i == y_i) { diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.h b/tensorflow/core/util/ctc/ctc_loss_calculator.h index be00895b0d3517fe06a852685f79f32e5a0b5167..dd1163310bf406b66bdd450ac6bf840272f7c592 100644 --- a/tensorflow/core/util/ctc/ctc_loss_calculator.h +++ b/tensorflow/core/util/ctc/ctc_loss_calculator.h @@ -130,13 +130,13 @@ Status CTCLossCalculator::CalculateLoss( for (int t = 1; t < num_time_steps; ++t) { if (inputs[t].rows() != batch_size) { return errors::InvalidArgument("Expected batch size at t: ", t, - " to be: ", batch_size, " but got: ", - inputs[t].rows()); + " to be: ", batch_size, + " but got: ", inputs[t].rows()); } if (inputs[t].cols() != num_classes) { return errors::InvalidArgument("Expected class count at t: ", t, - " to be: ", num_classes, " but got: ", - inputs[t].cols()); + " to be: ", num_classes, + " but got: ", inputs[t].cols()); } } @@ -282,8 +282,8 @@ Status CTCLossCalculator::PopulateLPrimes( LabelSequences* l_primes) const { // labels is a Label array of size batch_size if (labels.size() != batch_size) { - return errors::InvalidArgument("labels.size() != batch_size: ", - labels.size(), " vs. ", batch_size); + return errors::InvalidArgument( + "labels.size() != batch_size: ", labels.size(), " vs. ", batch_size); } *max_u_prime = 0; // keep track of longest l' modified label sequence. @@ -325,12 +325,13 @@ Status CTCLossCalculator::PopulateLPrimes( for (int l_i : l) { if (l_i < 0) { return errors::InvalidArgument( - "All labels must be nonnegative integers, batch: ", b, " labels: ", - str_util::Join(l, ",")); + "All labels must be nonnegative integers, batch: ", b, + " labels: ", str_util::Join(l, ",")); } else if (l_i >= num_classes) { return errors::InvalidArgument( - "No label may be greater than num_classes. ", "num_classes: ", - num_classes, ", batch: ", b, " labels: ", str_util::Join(l, ",")); + "No label may be greater than num_classes. ", + "num_classes: ", num_classes, ", batch: ", b, + " labels: ", str_util::Join(l, ",")); } } if (!ignore_longer_outputs_than_inputs) { diff --git a/tensorflow/core/util/cuda_device_functions.h b/tensorflow/core/util/cuda_device_functions.h index f787687f6628797ce9c7d21f65fb6fd983710bb6..f2d4e470c82d9a1480ac1bf7726a7a7a9ae08715 100644 --- a/tensorflow/core/util/cuda_device_functions.h +++ b/tensorflow/core/util/cuda_device_functions.h @@ -28,14 +28,10 @@ limitations under the License. #include #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "cuda/include/cuda.h" -#include "cuda/include/device_functions.h" #include "tensorflow/core/platform/types.h" -#if CUDA_VERSION >= 7050 -#include "cuda/include/cuda_fp16.h" -#endif // CUDA_VERSION >= 7050 - namespace tensorflow { namespace detail { @@ -394,6 +390,17 @@ __global__ void SetZero(const int count, T* ptr) { } } +// Helper to set all tensor entries to a specific value. +template +__global__ void SetToValue(const int count, T* ptr, T value) { + // Check that the grid is one dimensional and index doesn't overflow. + assert(blockDim.y == 1 && blockDim.z == 1); + assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x); + for (int i : CudaGridRangeX(count)) { + ptr[i] = value; + } +} + namespace detail { // Helper function for atomic accumulation implemented as CAS. template @@ -425,6 +432,47 @@ __device__ double CudaAtomicCasHelper(double* ptr, F accumulate) { })); } +// Overload of above function for half. Note that we don't have +// atomicCAS() for anything less than 32 bits, so we need to include the +// other 16 bits in the operation. +// +// This version is going to be very slow +// under high concurrency, since most threads will be spinning on failing +// their compare-and-swap tests. (The fact that we get false sharing on the +// neighboring fp16 makes this even worse.) If you are doing a large reduction, +// you are much better off with doing the intermediate steps in fp32 and then +// switching to fp16 as late as you can in the calculations. +// +// Note: Assumes little endian. +template +__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) { +#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) + static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian"); +#endif + namespace half_impl = Eigen::half_impl; + intptr_t intptr = reinterpret_cast(ptr); + assert(!(intptr & 0x1)); // should be 2-aligned. + if (intptr & 0x2) { + // The half is in the second part of the uint32 (upper 16 bits). + uint32* address = reinterpret_cast(intptr - 2); + uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { + unsigned short high = static_cast(arg >> 16); + Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high)); + return (static_cast(acc.x) << 16) | (arg & 0xffff); + }); + return half_impl::raw_uint16_to_half(static_cast(result >> 16)); + } else { + // The half is in the first part of the uint32 (lower 16 bits). + uint32* address = reinterpret_cast(intptr); + uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { + unsigned short low = static_cast(arg & 0xffff); + Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low)); + return (arg & 0xffff0000) | static_cast(acc.x); + }); + return half_impl::raw_uint16_to_half(static_cast(result & 0xffff)); + } +} + template using ToTypeIfConvertible = typename std::enable_if::value, To>::type; @@ -438,6 +486,14 @@ template __device__ detail::ToTypeIfConvertible CudaAtomicAdd(T* ptr, U value) { return atomicAdd(ptr, value); } + +__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return a + value; }); +} + + #if __CUDA_ARCH__ < 600 __device__ inline double CudaAtomicAdd(double* ptr, double value) { return detail::CudaAtomicCasHelper(ptr, @@ -455,27 +511,74 @@ __device__ inline double CudaAtomicAdd(double* ptr, double value) { return result; } #endif - +// CudaAtomicAdd +// Specializations of CudaAtomicAdd for complex types, which CudaAtomicAdd does +// not support. We treat a std::complex* as a T* (the C++ standard section +// 26.4.4 allows this explicitly) and atomic add the real and imaginary +// components individually. The operation as a whole is not atomic, but we can +// safely treat the components independently for the purpose of accumulating. +__device__ inline std::complex CudaAtomicAdd(std::complex* ptr, + std::complex value) { + auto ptr_scalar = reinterpret_cast(ptr); + return std::complex(CudaAtomicAdd(ptr_scalar, value.real()), + CudaAtomicAdd(ptr_scalar + 1, value.imag())); +} + +__device__ inline std::complex CudaAtomicAdd( + std::complex* ptr, std::complex value) { + auto ptr_scalar = reinterpret_cast(ptr); + return std::complex(CudaAtomicAdd(ptr_scalar, value.real()), + CudaAtomicAdd(ptr_scalar + 1, value.imag())); +} + +// CudaAtomicSub template __device__ detail::ToTypeIfConvertible CudaAtomicSub(T* ptr, U value) { return atomicSub(ptr, value); } + // Specializations of substraction which add the negative value. __device__ inline float CudaAtomicSub(float* ptr, float value) { return CudaAtomicAdd(ptr, -value); } + __device__ inline double CudaAtomicSub(double* ptr, double value) { return CudaAtomicAdd(ptr, -value); } + __device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr, tensorflow::uint64 value) { return CudaAtomicAdd(ptr, -value); } +__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return a - value; }); +} + +// CudaAtomicMax template __device__ detail::ToTypeIfConvertible CudaAtomicMax(T* ptr, U value) { return atomicMax(ptr, value); } + +__device__ inline float CudaAtomicMax(float* ptr, float value) { + return detail::CudaAtomicCasHelper( + ptr, [value](float a) { return max(a, value); }); +} + +__device__ inline double CudaAtomicMax(double* ptr, double value) { + return detail::CudaAtomicCasHelper( + ptr, [value](double a) { return max(a, value); }); +} + +__device__ inline Eigen::half CudaAtomicMax(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return max(a, value); }); +} + #if __CUDA_ARCH__ < 320 __device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr, tensorflow::uint64 value) { @@ -484,10 +587,43 @@ __device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr, } #endif +// CudaAtomicMin +template +__device__ detail::ToTypeIfConvertible CudaAtomicMin(T* ptr, U value) { + return atomicMin(ptr, value); +} + +__device__ inline float CudaAtomicMin(float* ptr, float value) { + return detail::CudaAtomicCasHelper( + ptr, [value](float a) { return min(a, value); }); +} + +__device__ inline double CudaAtomicMin(double* ptr, double value) { + return detail::CudaAtomicCasHelper( + ptr, [value](double a) { return min(a, value); }); +} + +__device__ inline Eigen::half CudaAtomicMin(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return min(a, value); }); +} + +#if __CUDA_ARCH__ < 320 +__device__ inline tensorflow::uint64 CudaAtomicMin(tensorflow::uint64* ptr, + tensorflow::uint64 value) { + return detail::CudaAtomicCasHelper( + ptr, [value](tensorflow::uint64 a) { return min(a, value); }); +} +#endif + +// CudaAtomicMul template __device__ detail::ToTypeIfConvertible CudaAtomicMul(T* ptr, U value) { return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; }); } + +// CudaAtomicDiv template __device__ detail::ToTypeIfConvertible CudaAtomicDiv(T* ptr, U value) { return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; }); diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h index 18a4c008f138bb4ba3b1e4c381781e0c363863f7..3c59524cb6f85911544b8f2d7d3339e19af7f5b4 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/cuda_kernel_helper.h @@ -90,60 +90,6 @@ __device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync( CudaShuffleXorSync(mask, static_cast(value), lane_mask, width)); } -namespace detail { -// Overload of above function for half. Note that we don't have -// atomicCAS() for anything less than 32 bits, so we need to include the -// other 16 bits in the operation. -// -// This version is going to be very slow -// under high concurrency, since most threads will be spinning on failing -// their compare-and-swap tests. (The fact that we get false sharing on the -// neighboring fp16 makes this even worse.) If you are doing a large reduction, -// you are much better off with doing the intermediate steps in fp32 and then -// switching to fp16 as late as you can in the calculations. -// -// Note: Assumes little endian. -template -__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) { -#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) - static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian"); -#endif - namespace half_impl = Eigen::half_impl; - intptr_t intptr = reinterpret_cast(ptr); - assert(!(intptr & 0x1)); // should be 2-aligned. - if (intptr & 0x2) { - // The half is in the second part of the uint32 (upper 16 bits). - uint32* address = reinterpret_cast(intptr - 2); - uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { - unsigned short high = static_cast(arg >> 16); - Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high)); - return (static_cast(acc.x) << 16) | (arg & 0xffff); - }); - return half_impl::raw_uint16_to_half(static_cast(result >> 16)); - } else { - // The half is in the first part of the uint32 (lower 16 bits). - uint32* address = reinterpret_cast(intptr); - uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { - unsigned short low = static_cast(arg & 0xffff); - Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low)); - return (arg & 0xffff0000) | static_cast(acc.x); - }); - return half_impl::raw_uint16_to_half(static_cast(result & 0xffff)); - } -} -} // namespace detail - -__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr, - Eigen::half value) { - return detail::CudaAtomicCasHelper( - ptr, [value](Eigen::half a) { return a + value; }); -} -__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr, - Eigen::half value) { - return detail::CudaAtomicCasHelper( - ptr, [value](Eigen::half a) { return a - value; }); -} - namespace cuda_helper { template __device__ IntType upper_bound(IntType* first, IntType count, IntType val) { diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc index bd4c356ea01b81f4f7bb230481f4b9bef981ada4..732ed33ede17bc90d3301d3f1eee6302a96028d7 100644 --- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc +++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc @@ -149,27 +149,27 @@ class CudaLaunchConfigTest : public ::testing::Test { TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) { CudaLaunchConfig cfg; - // test valid inputs - #define TEST_LAUNCH_PARAMETER(work_element_count) \ - cfg = GetCudaLaunchConfig(bufsize, d); \ - SetOutbufZero<<>> \ - (cfg, outbuf); \ - CUDA_ASSERT_SUCCESS \ - cfg = GetCudaLaunchConfig(work_element_count, d); \ - Count1D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0));\ - \ - cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ - SetOutbufZero<<>> \ - (cfg, outbuf); \ - CUDA_ASSERT_SUCCESS \ - cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0); \ - Count1D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)) +// test valid inputs +#define TEST_LAUNCH_PARAMETER(work_element_count) \ + cfg = GetCudaLaunchConfig(bufsize, d); \ + SetOutbufZero<<>>( \ + cfg, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCudaLaunchConfig(work_element_count, d); \ + Count1D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)); \ + \ + cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + SetOutbufZero<<>>( \ + cfg, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0); \ + Count1D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)) TEST_LAUNCH_PARAMETER(128); TEST_LAUNCH_PARAMETER(129); @@ -181,7 +181,7 @@ TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) { TEST_LAUNCH_PARAMETER(8192); TEST_LAUNCH_PARAMETER(123456); TEST_LAUNCH_PARAMETER(1 << 30); - #undef TEST_LAUNCH_PARAMETER +#undef TEST_LAUNCH_PARAMETER } bool operator==(const Cuda2DLaunchConfig& a, const Cuda2DLaunchConfig& b) { @@ -200,27 +200,27 @@ TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) { Cuda2DLaunchConfig cfg; CudaLaunchConfig cfg1d; - // test valid inputs - #define TEST_LAUNCH_PARAMETER(dimx, dimy) \ - cfg1d = GetCudaLaunchConfig(bufsize, d); \ - SetOutbufZero<<>> \ - (cfg1d, outbuf);\ - CUDA_ASSERT_SUCCESS \ - cfg = GetCuda2DLaunchConfig(dimx, dimy, d); \ - Count2D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \ - \ - cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ - SetOutbufZero<<>> \ - (cfg1d, outbuf);\ - CUDA_ASSERT_SUCCESS \ - cfg = GetCuda2DLaunchConfig(dimx, dimy, d, Count2D, 0, 0); \ - Count2D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0)) +// test valid inputs +#define TEST_LAUNCH_PARAMETER(dimx, dimy) \ + cfg1d = GetCudaLaunchConfig(bufsize, d); \ + SetOutbufZero<<>>( \ + cfg1d, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCuda2DLaunchConfig(dimx, dimy, d); \ + Count2D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \ + \ + cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + SetOutbufZero<<>>( \ + cfg1d, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCuda2DLaunchConfig(dimx, dimy, d, Count2D, 0, 0); \ + Count2D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0)) TEST_LAUNCH_PARAMETER(128, 128); TEST_LAUNCH_PARAMETER(129, 64); @@ -233,24 +233,24 @@ TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) { TEST_LAUNCH_PARAMETER(123456, 12); TEST_LAUNCH_PARAMETER(1, 1 << 30); TEST_LAUNCH_PARAMETER(1 << 30, 1); - #undef TEST_LAUNCH_PARAMETER +#undef TEST_LAUNCH_PARAMETER } TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) { Cuda3DLaunchConfig cfg; CudaLaunchConfig cfg1d; - // test valid inputs - #define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz) \ - cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ - SetOutbufZero<<>> \ - (cfg1d, outbuf);\ - CUDA_ASSERT_SUCCESS \ - cfg = GetCuda3DLaunchConfig(dimx, dimy, dimz, d, Count3D, 0, 0); \ - Count3D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(dimx * dimy * dimz, std::accumulate(outbuf, outbuf + bufsize, 0)) +// test valid inputs +#define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz) \ + cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + SetOutbufZero<<>>( \ + cfg1d, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCuda3DLaunchConfig(dimx, dimy, dimz, d, Count3D, 0, 0); \ + Count3D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(dimx* dimy* dimz, std::accumulate(outbuf, outbuf + bufsize, 0)) TEST_LAUNCH_PARAMETER(128, 128, 128); TEST_LAUNCH_PARAMETER(129, 64, 1024); @@ -264,7 +264,7 @@ TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) { TEST_LAUNCH_PARAMETER(1, 1, 1 << 30); TEST_LAUNCH_PARAMETER(1, 1 << 30, 1); TEST_LAUNCH_PARAMETER(1 << 30, 1, 1); - #undef TEST_LAUNCH_PARAMETER +#undef TEST_LAUNCH_PARAMETER } TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) { diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc index 9b6a8e12511448b72e17a0b20a4418c4a5cd2c7a..13e41c17f7c7df5ad581bd3f6a39051641139258 100644 --- a/tensorflow/core/util/example_proto_fast_parsing_test.cc +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -57,6 +57,7 @@ void TestCorrectness(const string& serialized) { Example example; Example fast_example; EXPECT_TRUE(example.ParseFromString(serialized)); + example.DiscardUnknownFields(); EXPECT_TRUE(TestFastParse(serialized, &fast_example)); EXPECT_EQ(example.DebugString(), fast_example.DebugString()); if (example.DebugString() != fast_example.DebugString()) { diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc index 41f56d2daa48e651f5ac4051deae9c05ef1ed859..e156a3bc8f0f01acc543e9b385bd9782870be52a 100644 --- a/tensorflow/core/util/example_proto_helper.cc +++ b/tensorflow/core/util/example_proto_helper.cc @@ -247,8 +247,9 @@ Status SingleExampleProtoToTensors( bool types_match; TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match)); if (!types_match) { - return errors::InvalidArgument("Name: ", example_name, ", Feature: ", - key, ". Data types don't match. ", + return errors::InvalidArgument("Name: ", example_name, + ", Feature: ", key, + ". Data types don't match. ", "Expected type: ", DataTypeString(dtype), " Feature is: ", ProtoDebugString(f)); } @@ -278,8 +279,9 @@ Status SingleExampleProtoToTensors( bool types_match; TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match)); if (!types_match) { - return errors::InvalidArgument("Name: ", example_name, ", Feature: ", - key, ". Data types don't match. ", + return errors::InvalidArgument("Name: ", example_name, + ", Feature: ", key, + ". Data types don't match. ", "Expected type: ", DataTypeString(dtype), " Feature is: ", ProtoDebugString(f)); } diff --git a/tensorflow/core/util/memmapped_file_system_test.cc b/tensorflow/core/util/memmapped_file_system_test.cc index 616eb5dac32188688ac01cf49ff583dc1623d5ad..504d2d353f8f76f77e4efd3e4a6a6edcaa200711 100644 --- a/tensorflow/core/util/memmapped_file_system_test.cc +++ b/tensorflow/core/util/memmapped_file_system_test.cc @@ -144,8 +144,8 @@ TEST(MemmappedFileSystemTest, ProxyToDefault) { TF_ASSERT_OK(memmapped_env.NewAppendableFile(filename, &writable_file_temp)); // Making sure to clean up after the test finishes. const auto adh = [&memmapped_env, &filename](WritableFile* f) { - delete f; - TF_CHECK_OK(memmapped_env.DeleteFile(filename)); + delete f; + TF_CHECK_OK(memmapped_env.DeleteFile(filename)); }; std::unique_ptr writable_file( writable_file_temp.release(), adh); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 34ef7ba21b442b7290aeb38565a402f3dbe707f1..4467373c0060bb4dd25108891e2ff51d903a2453 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -210,31 +210,32 @@ class MklShape { CHECK_EQ(dnnDelete_F32(convert), E_SUCCESS); } -// The following methods are used for serializing and de-serializing the -// contents of the mklshape object. -// The data is serialized in this order -// isMklTensor_ -// dimension_ -// sizes_ -// strides_ -// mklLayout_ -// tfLayout_ -// tf_to_mkl_dim_map_ + // The following methods are used for serializing and de-serializing the + // contents of the mklshape object. + // The data is serialized in this order + // isMklTensor_ + // dimension_ + // sizes_ + // strides_ + // mklLayout_ + // tfLayout_ + // tf_to_mkl_dim_map_ #define SIZE_OF_MKL_DNN_BUF \ (dnnLayoutSerializationBufferSize_F32()) // Size of buffer needed to // serialize dnn_layout pointer -// Size of buffer to hold the serialized object, the size is computed as follows -// sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) + sizeof(strides_) -// + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer) -// + sizeof(tf_to_mkl_dim_map_) + // Size of buffer to hold the serialized object, the size is computed as + // follows sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) + + // sizeof(strides_) + // + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer) + // + sizeof(tf_to_mkl_dim_map_) #define SIZE_OF_MKL_SERIAL_DATA(dims) \ (2 * sizeof(size_t) + 3 * dims * sizeof(size_t) + 2 * SIZE_OF_MKL_DNN_BUF) -// First we need to define some macro for offsets into the serial buffer where -// different elements of Mklshape is written/read from + // First we need to define some macro for offsets into the serial buffer where + // different elements of Mklshape is written/read from #define IS_MKL_TENSOR_OFFSET 0 // Location from start of buffer where isMklTensor_ is serialized @@ -388,7 +389,7 @@ class MklDnnShape { /// Equality function for MklDnnShape objects /// @return true if both are equal; false otherwise. - inline bool operator == (const MklDnnShape& input_shape) const { + inline bool operator==(const MklDnnShape& input_shape) const { if (this->IsMklTensor() != input_shape.IsMklTensor()) { return false; } @@ -406,7 +407,7 @@ class MklDnnShape { /// Equality operator for MklDnnShape and TFShape. /// Returns: true if TF shapes for both are the same, false otherwise - inline bool operator == (const TensorShape& input_shape) const { + inline bool operator==(const TensorShape& input_shape) const { if (!this->IsMklTensor()) { return false; } @@ -425,7 +426,7 @@ class MklDnnShape { inline size_t GetDimension(char dimension) const { int index = GetMklDnnTensorDimIndex(dimension); CHECK(index >= 0 && index < this->GetDimension()) - << "Invalid index from the dimension: " << index << ", " << dimension; + << "Invalid index from the dimension: " << index << ", " << dimension; return this->DimSize(index); } @@ -705,8 +706,8 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, Tensor output_tensor; TensorShape output_shape; - TF_CHECK_OK(Status(error::Code::UNIMPLEMENTED, - "Unimplemented conversion function")); + TF_CHECK_OK( + Status(error::Code::UNIMPLEMENTED, "Unimplemented conversion function")); return output_tensor; } @@ -973,8 +974,8 @@ inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) { return mkl_shape.dim_size(index); } -inline void CopyMklTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -995,8 +996,8 @@ inline void CopyMklTensorInToOut(OpKernelContext* context, } #ifdef INTEL_MKL_ML -inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, - int idx_in, int idx_out, +inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in, + int idx_out, const TensorShape& shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); @@ -1013,8 +1014,8 @@ inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, context->set_output(idx_data_out, output); } #else -inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, - int idx_in, int idx_out, +inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in, + int idx_out, const TensorShape& shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); @@ -1034,8 +1035,8 @@ inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, #ifdef INTEL_MKL_ML -inline void ForwardTfTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1053,8 +1054,8 @@ inline void ForwardTfTensorInToOut(OpKernelContext* context, #else -inline void ForwardTfTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1072,8 +1073,8 @@ inline void ForwardTfTensorInToOut(OpKernelContext* context, #endif -inline void ForwardMklTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1092,8 +1093,8 @@ inline void ForwardMklTensorInToOut(OpKernelContext* context, #ifndef INTEL_MKL_ML inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context, - int idx_in, int idx_out, - const MklDnnShape& mkl_shape) { + int idx_in, int idx_out, + const MklDnnShape& mkl_shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1216,11 +1217,11 @@ inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) { int64 H = input.dim_size(1); int64 W = input.dim_size(2); int64 C = input.dim_size(3); - int64 stride_n = H*W*C; -# pragma omp parallel for num_threads(16) + int64 stride_n = H * W * C; +#pragma omp parallel for num_threads(16) for (int64 n = 0; n < N; ++n) { - mkl_somatcopy('R', 'T', H*W, C, 1, buf_in + n*stride_n, C, - buf_out + n*stride_n, H*W); + mkl_somatcopy('R', 'T', H * W, C, 1, buf_in + n * stride_n, C, + buf_out + n * stride_n, H * W); } } @@ -1232,11 +1233,11 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) { int64 H = (*output)->dim_size(1); int64 W = (*output)->dim_size(2); int64 C = (*output)->dim_size(3); - int64 stride_n = H*W*C; -# pragma omp parallel for num_threads(16) + int64 stride_n = H * W * C; +#pragma omp parallel for num_threads(16) for (int64 n = 0; n < N; ++n) { - mkl_somatcopy('R', 'T', C, H*W, 1, buf_in + n*stride_n, H*W, - buf_out + n*stride_n, C); + mkl_somatcopy('R', 'T', C, H * W, 1, buf_in + n * stride_n, H * W, + buf_out + n * stride_n, C); } } @@ -1279,10 +1280,11 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { /// @return: Tensorflow data format corresponding to memory::format /// Fails with an error if invalid data format. inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) { - if (format == memory::format::nhwc) return FORMAT_NHWC; - else if (format == memory::format::nchw) return FORMAT_NCHW; - TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, - "Unsupported data format")); + if (format == memory::format::nhwc) + return FORMAT_NHWC; + else if (format == memory::format::nchw) + return FORMAT_NCHW; + TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure // that we don't come here. @@ -1425,7 +1427,6 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim, return memory::desc(md); } - /* * Class to represent all the resources corresponding to a tensor in TensorFlow * that are required to execute an operation (such as Convolution). @@ -1494,7 +1495,7 @@ class MklDnnData { /// @return: memory::desc object corresponding to blocked memory format /// for given dimensions and strides. static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim, - const memory::dims& strides) { + const memory::dims& strides) { return CreateBlockedMemDescHelper(dim, strides, MklDnnType()); } @@ -1563,7 +1564,6 @@ class MklDnnData { return user_memory_->get_primitive_desc(); } - /// Get function for descriptor of user memory. inline memory::desc GetUsrMemDesc() { // This is ugly. Why MKL-DNN does not provide desc() method of const type?? @@ -1634,7 +1634,8 @@ class MklDnnData { /// @return: true in case reorder of input is needed; false, otherwise. inline bool IsReorderNeeded(const memory::format& target_format) const { CHECK_NOTNULL(user_memory_); - return target_format != user_memory_->get_primitive_desc().desc().data.format; + return target_format != + user_memory_->get_primitive_desc().desc().data.format; } /// Function to create a reorder from memory pointed by from to memory pointed diff --git a/tensorflow/core/util/presized_cuckoo_map.h b/tensorflow/core/util/presized_cuckoo_map.h index e7dab830f0ec9e3401d621f04358d3ee62cb0b63..f88ad2faaff344832d65b04357c3d8c2665ebad5 100644 --- a/tensorflow/core/util/presized_cuckoo_map.h +++ b/tensorflow/core/util/presized_cuckoo_map.h @@ -67,7 +67,7 @@ inline uint64 multiply_high_u64(uint64 x, uint64 y) { return prod_hi + (prod_mid1 >> 32) + (prod_mid2 >> 32) + carry; #endif } -} +} // namespace presized_cuckoo_map template class PresizedCuckooMap { diff --git a/tensorflow/core/util/reporter_test.cc b/tensorflow/core/util/reporter_test.cc index 1cb07718feee820c334d8f5183cafb2de0cb009b..575c27d4ef72ec33c4b9352de59fc806b12d6385 100644 --- a/tensorflow/core/util/reporter_test.cc +++ b/tensorflow/core/util/reporter_test.cc @@ -29,8 +29,8 @@ namespace { // Tests of all the error paths in log_reader.cc follow: static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " - << expected; + EXPECT_TRUE(StringPiece(s).contains(expected)) + << s << " does not contain " << expected; } TEST(TestReporter, NoLogging) { diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index f2401a0af4e60f66c606e86e90a37bcf09eb6308..258ee418c145bae161c7603d4249875fb687c94a 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/sparse/dim_comparator.h" #include "tensorflow/core/util/sparse/group_iterator.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace sparse { @@ -59,8 +59,8 @@ class SparseTensor { shape_(shape.begin(), shape.end()), order_(order.begin(), order.end()), dims_(GetDimsFromIx(ix)) { - CHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: " - << ix.dtype(); + CHECK_EQ(ix.dtype(), DT_INT64) + << "indices must be type int64 but got: " << ix.dtype(); CHECK(TensorShapeUtils::IsVector(vals.shape())) << "vals must be a vec, but got: " << vals.shape().DebugString(); CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc index efdd97fd3d6ffa5c1f66f2a0950d7bd44ba01eb1..85de0320857e307ea54594c2eff611b9e413945b 100644 --- a/tensorflow/core/util/sparse/sparse_tensor_test.cc +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace sparse { diff --git a/tensorflow/core/util/stream_executor_util.h b/tensorflow/core/util/stream_executor_util.h index 6a5ddec04c9d6c2f723e0caa7343103f09c63183..f7767ace716782e53a2023bea7acc7b2f3c6604c 100644 --- a/tensorflow/core/util/stream_executor_util.h +++ b/tensorflow/core/util/stream_executor_util.h @@ -41,9 +41,10 @@ class StreamExecutorUtil { // This assumes that the error codes between the two implementations // match. static Status ConvertStatus(const perftools::gputools::port::Status& s) { - return s.ok() ? Status::OK() : Status(static_cast( - static_cast(s.code())), - s.error_message()); + return s.ok() ? Status::OK() + : Status(static_cast( + static_cast(s.code())), + s.error_message()); } }; diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 579b70ab5149f05749205f24a0c6e64c95f12dfd..462b420976e63ca63079fd652fdb12c5ef2a1404 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -913,8 +913,8 @@ Status BundleReader::LookupSlice(StringPiece full_tensor_key, Status BundleReader::GetSliceValue(StringPiece full_tensor_key, const BundleEntryProto& full_tensor_entry, const TensorSlice& slice_spec, Tensor* val) { - using checkpoint::TensorSliceSet; using checkpoint::RegisterTensorSlice; + using checkpoint::TensorSliceSet; DCHECK_GE(full_tensor_entry.slices_size(), 0); const TensorShape full_shape(TensorShape(full_tensor_entry.shape())); diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc index 0f009d7de57a3cf1471c1ba694d3a771bc00635c..424f8098a9c1e3cec3851be06d04d49bed93e9af 100644 --- a/tensorflow/core/util/tensor_slice_reader_cache.cc +++ b/tensorflow/core/util/tensor_slice_reader_cache.cc @@ -55,7 +55,7 @@ const TensorSliceReader* TensorSliceReaderCache::GetReader( TensorSliceReader::OpenTableFunction open_function, int preferred_shard) { mutex_lock l(mu_); -#if defined(__GXX_RTTI) || defined(_CPPRTTI) +#if defined(__GXX_RTTI) || defined(_CPPRTTI) // Get the function pointer from the open_function value. TensorSliceReaderCache::OpenFuncType* func_ptr = open_function.target(); diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc index 4217df90ca147ccc17cadf6c46c6e4ef4524f12b..7c1d325c0a54e7ba5261f645a2962970fa2d3630 100644 --- a/tensorflow/core/util/tensor_slice_set.cc +++ b/tensorflow/core/util/tensor_slice_set.cc @@ -188,9 +188,9 @@ Status RegisterTensorSlice( } if (type != tss->type()) { return errors::Internal("Incompatible tensor types detected for tensor ", - name, ": existing = ", - DataTypeString(tss->type()), ", new = ", - DataTypeString(type)); + name, + ": existing = ", DataTypeString(tss->type()), + ", new = ", DataTypeString(type)); } } // Register the tensor slices without the actual data. diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h index c7edae66b267d4cbd88d497c745b4d81802ab3a9..8f5a6f1d93591e94ec759d343ec26146c67552c0 100644 --- a/tensorflow/core/util/tensor_slice_util.h +++ b/tensorflow/core/util/tensor_slice_util.h @@ -139,9 +139,9 @@ static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape, const TensorSlice& slice_d, const SrcT* ptr_s, DstT* ptr_d) { - CHECK_LE(shape.dims(), kTensorSliceMaxRank) << "Only tensors of size up to " - << kTensorSliceMaxRank - << " are supported"; + CHECK_LE(shape.dims(), kTensorSliceMaxRank) + << "Only tensors of size up to " << kTensorSliceMaxRank + << " are supported"; // We need to compute the intersection of the two slices. TensorSlice inter; if (!slice_s.Intersect(slice_d, &inter)) { diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h index bdb4921e1bbf8611d84420c1e52d01fa39c25264..2888c66d10fa3c2ab0eaf755a23da3eb3fcd6b09 100644 --- a/tensorflow/core/util/tensor_slice_writer.h +++ b/tensorflow/core/util/tensor_slice_writer.h @@ -101,8 +101,8 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, // The tensor and the slice have to be compatible if (shape.dims() != slice.dims()) { return errors::Internal("Incompatible tensor shape and slice: ", "shape = ", - shape.DebugString(), ", slice = ", - slice.DebugString()); + shape.DebugString(), + ", slice = ", slice.DebugString()); } DataType dt = DataTypeToEnum::value; // We need to add an entry for "name" if there isn't an entry already. @@ -114,9 +114,9 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, CHECK_EQ(name, ssm.name()) << ProtoShortDebugString(ssm); TensorShape ssm_shape(ssm.shape()); if (!shape.IsSameSize(ssm_shape)) { - return errors::Internal("Mismatching shapes: existing tensor = ", - ssm_shape.DebugString(), ", trying to add name ", - name, ", shape = ", shape.DebugString()); + return errors::Internal( + "Mismatching shapes: existing tensor = ", ssm_shape.DebugString(), + ", trying to add name ", name, ", shape = ", shape.DebugString()); } if (dt != ssm.type()) { return errors::Internal( diff --git a/tensorflow/docs_src/about/bib.md b/tensorflow/docs_src/about/bib.md index c9f0c532c62791a9fcf854f11fd2f330955ee7d6..5593a3d95c435df38174fde5db37f4dd3437acd4 100644 --- a/tensorflow/docs_src/about/bib.md +++ b/tensorflow/docs_src/about/bib.md @@ -60,7 +60,7 @@ author={ Lukasz~Kaiser and Manjunath~Kudlur and Josh~Levenberg and - Dan~Man\'{e} and + Dandelion~Man\'{e} and Rajat~Monga and Sherry~Moore and Derek~Murray and diff --git a/tensorflow/docs_src/api_guides/python/TPUEstimator.md b/tensorflow/docs_src/api_guides/python/TPUEstimator.md new file mode 100644 index 0000000000000000000000000000000000000000..d74d7f3181c9cf44e6c97e13742db682858f4694 --- /dev/null +++ b/tensorflow/docs_src/api_guides/python/TPUEstimator.md @@ -0,0 +1,396 @@ +# Using TPUs + +This document walks through the principal TensorFlow APIs necessary to make +effective use of a [Cloud TPU](https://cloud.google.com/tpu/), and highlights +the differences between regular TensorFlow usage, and usage on a TPU. + +This doc is aimed at users who: + +* Are familiar with TensorFlow's `Estimator` and `Dataset` APIs +* Have maybe [tried out a Cloud TPU](https://cloud.google.com/tpu/docs/quickstart) + using an existing model. +* Have, perhaps, skimmed the code of an example TPU model + [[1]](https://github.com/tensorflow/models/blob/master/official/mnist/mnist_tpu.py) + [[2]](https://github.com/tensorflow/tpu-demos/tree/master/cloud_tpu/models). +* Are interested in porting an existing `Estimator` model to + run on Cloud TPUs + +## TPUEstimator + +@{tf.estimator.Estimator$Estimators} are TensorFlow's model-level abstraction. +Standard `Estimators` can drive models on CPU and GPUs. You must use +@{tf.contrib.tpu.TPUEstimator} to drive a model on TPUs. + +Refer to TensorFlow's Getting Started section for an introduction to the basics +of using a @{$get_started/premade_estimators$pre-made `Estimator`}, and +@{$get_started/custom_estimators$custom `Estimator`s}. + +The `TPUEstimator` class differs somewhat from the `Estimator` class. + +The simplest way to maintain a model that can be run both on CPU/GPU or on a +Cloud TPU is to define the model's inference phase (from inputs to predictions) +outside of the `model_fn`. Then maintain separate implementations of the +`Estimator` setup and `model_fn`, both wrapping this inference step. For an +example of this pattern compare the `mnist.py` and `mnist_tpu.py` implementation in +[tensorflow/models](https://github.com/tensorflow/models/tree/master/official/mnist). + +### Running a `TPUEstimator` locally + +To create a standard `Estimator` you call the constructor, and pass it a +`model_fn`, for example: + +``` +my_estimator = tf.estimator.Estimator( + model_fn=my_model_fn) +``` + +The changes required to use a @{tf.contrib.tpu.TPUEstimator} on your local +machine are relatively minor. The constructor requires two additional arguments. +You should set the `use_tpu` argument to `False`, and pass a +@{tf.contrib.tpu.RunConfig} as the `config` argument, as shown below: + +``` python +my_tpu_estimator = tf.contrib.tpu.TPUEstimator( + model_fn=my_model_fn, + config=tf.contrib.tpu.RunConfig() + use_tpu=False) +``` + +Just this simple change will allow you to run a `TPUEstimator` locally. +The majority of example TPU models can be run in this local mode, +by setting the command line flags as follows: + + +``` +$> python mnist_tpu.py --use_tpu=false --master='' +``` + +Note: This `use_tpu=False` argument is useful for trying out the `TPUEstimator` +API. It is not meant to be a complete TPU compatibility test. Successfully +running a model locally in a `TPUEstimator` does not guarantee that it will +work on a TPU. + + +### Building a `tpu.RunConfig` + +While the default `RunConfig` is sufficient for local training, these settings +cannot be ignored in real usage. + +A more typical setup for a `RunConfig`, that can be switched to use a Cloud +TPU, might be as follows: + +``` python +import tempfile +import subprocess + +class FLAGS(object): + use_tpu=False + tpu_name=None + # Use a local temporary path for the `model_dir` + model_dir = tempfile.mkdtemp() + # Number of training steps to run on the Cloud TPU before returning control. + iterations = 50 + # A single Cloud TPU has 8 shards. + num_shards = 8 + +if FLAGS.use_tpu: + my_project_name = subprocess.check_output([ + 'gcloud','config','get-value','project']) + my_zone = subprocess.check_output([ + 'gcloud','config','get-value','compute/zone']) + cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + tpu_names=[FLAGS.tpu_name], + zone=my_zone, + project=my_project) + master = tpu_cluster_resolver.get_master() +else: + master = '' + +my_tpu_run_config = tf.contrib.tpu.RunConfig( + master=master, + evaluation_master=master, + model_dir=FLAGS.model_dir, + session_config=tf.ConfigProto( + allow_soft_placement=True, log_device_placement=True), + tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, + FLAGS.num_shards), +) +``` + +Then you must pass the @{tf.contrib.tpu.RunConfig} to the constructor: + +``` python +my_tpu_estimator = tf.contrib.tpu.TPUEstimator( + model_fn=my_model_fn, + config = my_tpu_run_config, + use_tpu=FLAGS.use_tpu) +``` + +Typically the `FLAGS` would be set by command line arguments. To switch from +training locally to training on a cloud TPU you would need to: + + 1) Set `FLAGS.use_tpu` to `True` + 1) Set `FLAGS.tpu_name` so the + `tf.contrib.cluster_resolver.TPUClusterResolver` can find it + 1) Set `FLAGS.model_dir` to a Google Cloud Storage bucket url (`gs://`). + + +## Optimizer + +When training on a cloud TPU you **must** wrap the optimizer in a +@{tf.contrib.tpu.CrossShardOptimizer}, which uses an `allreduce` to aggregate +gradients and broadcast the result to each shard (each TPU core). + +The `CrossShardOptimizer` is not compatible with local training. So, to have +the same code run both locally and on a Cloud TPU, add lines like the following: + +``` python +optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) +if FLAGS.use_tpu: + optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) +``` + +If you prefer to avoid a global `FLAGS` variable in your model code, one +approach is to set the optimizer as one of the `Estimator`'s params, +as follows: + +``` python +my_tpu_estimator = tf.contrib.tpu.TPUEstimator( + model_fn=my_model_fn, + config = my_tpu_run_config, + use_tpu=FLAGS.use_tpu, + params={'optimizer':optimizer}) +``` + +## Model Function + +This section details the changes you must make to the model function +(`model_fn()`) to make it `TPUEstimator` compatible. + +### Static shapes + +During regular usage TensorFlow attempts to determine the shapes of each +`tf.Tensor` during graph construction. During execution any unknown shape +dimensions are determined dynamically, +see @{$programmers_guide/tensors#shape$Tensor Shapes} for more details. + +To run on Cloud TPUs TensorFlow models are compiled using @{$xla$XLA}. +XLA uses a similar system for determining shapes at compile time. XLA requires +that all tensor dimensions be statically defined at compile time. All shapes +must evaluate to a constant, and not depend on external data, or stateful +operations like variables or a random number generator. + + +### Summaries + +Remove any use of `tf.summary` from your model. + +@{$summaries_and_tensorboard$TensorBoard summaries} are a great way see inside +your model. A minimal set of basic summaries are automatically recorded by the +`TPUEstimator`, to `event` files in the `model_dir`. Custom summaries, however, +are currently unsupported when training on a Cloud TPU. So while the +`TPUEstimator` will still run locally with summaries, it will fail if used on a +TPU. + +### Metrics + +Build your evaluation metrics dictionary in a stand-alone `metric_fn`. + + + +Evaluation metrics are an essential part of training a model. These are fully +supported on Cloud TPUs, but with a slightly different syntax. + +A standard @{tf.metrics} returns two tensors. The first returns the running +average of the metric value, while the second updates the running average and +returns the value for this batch: + +``` +running_average, current_batch = tf.metrics.accuracy(labels, predictions) +``` + +In a standard `Estimator` you create a dictionary of these pairs, and return it +as part of the `EstimatorSpec`. + +```python +my_metrics = {'accuracy': tf.metrics.accuracy(labels, predictions)} + +return tf.estimator.EstimatorSpec( + ... + eval_metric_ops=my_metrics +) +``` + +In a `TPUEstimator` you instead pass a function (which returns a metrics +dictionary) and a list of argument tensors, as shown below: + +```python +def my_metric_fn(labels, predictions): + return {'accuracy': tf.metrics.accuracy(labels, predictions)} + +return tf.contrib.tpu.TPUEstimatorSpec( + ... + eval_metrics=(my_metric_fn, [labels, predictions]) +) +``` + +### Use `TPUEstimatorSpec` + +`TPUEstimatorSpec` do not support hooks, and require function wrappers for +some fields. + +An `Estimator`'s `model_fn` must return an `EstimatorSpec`. An `EstimatorSpec` +is a simple structure of named fields containing all the `tf.Tensors` of the +model that the `Estimator` may need to interact with. + +`TPUEstimators` use a @{tf.contrib.tpu.TPUEstimatorSpec}. There are a few +differences between it and a standard @{tf.estimator.EstimatorSpec}: + + +* The `eval_metric_ops` must be wrapped into a `metrics_fn`, this field is + renamed `eval_metrics` ([see above](#metrics)). +* The @{tf.train.SessionRunHook$hooks} are unsupported, so these fields are + omitted. +* The @{tf.train.Scaffold$`scaffold`}, if used, must also be wrapped in a + function. This field is renamed to `scaffold_fn`. + +`Scaffold` and `Hooks` are for advanced usage, and can typically be omitted. + +## Input functions + +Input functions work mainly unchanged as they run on the host computer, not the +Cloud TPU itself. This section explains the two necessary adjustments. + +### Params argument + + + +The `input_fn` for a standard `Estimator` _can_ include a +`params` argument; the `input_fn` for a `TPUEstimator` *must* include a +`params` argument. This is necessary to allow the estimator to set the batch +size for each replica of the input stream. So the minimum signature for an +`input_fn` for a `TPUEstimator` is: + +``` +def my_input_fn(params): + pass +``` + +Where `params['batch-size']` will contain the batch size. + +### Static shapes and batch size + +The input pipeline generated by your `input_fn` is run on CPU. So it is mostly +free strict static shape requirements imposed by the XLA/TPU environment. The +one requirement is that the batches of data fed from your input pipeline to +the TPU have a static shape, as determined by the standard TensorFlow shape +inference algorithm. Intermediate tensors are free to have a dynamic shapes. +If shape inference has failed, but the shape is known it is possible to +impose the correct shape using `tf.set_shape()`. + +In the example below the shape +inference algorithm fails, but it is corrected using `set_shape`: + +``` +>>> x = tf.zeros(tf.constant([1,2,3])+1) +>>> x.shape + +TensorShape([Dimension(None), Dimension(None), Dimension(None)]) + +>>> x.set_shape([2,3,4]) +``` + +In many cases the batch size is the only unknown dimension. + +A typical input pipeline, using `tf.data`, will usually produce batches of a +fixed size. The last batch of a finite `Dataset`, however, is typically smaller, +containing just the remaining elements. Since a `Dataset` does not know its own +length or finiteness, the standard @{tf.data.Dataset.batch$`batch`} method +cannot determine if all batches will have a fixed size batch on its own: + +``` +>>> params = {'batch_size':32} +>>> ds = tf.data.Dataset.from_tensors([0, 1, 2]) +>>> ds = ds.repeat().batch(params['batch-size']) +>>> ds + + +``` + +The most straightforward fix is to +@{tf.data.Dataset.apply$apply} @{tf.contrib.data.batch_and_drop_remainder} +as follows: + +``` +>>> params = {'batch_size':32} +>>> ds = tf.data.Dataset.from_tensors([0, 1, 2]) +>>> ds = ds.repeat().apply( +... tf.contrib.data.batch_and_drop_remainder(params['batch-size'])) +>>> ds + + <_RestructuredDataset shapes: (32, 3), types: tf.int32> +``` + +The one downside to this approach is that, as the name implies, this batching +method throws out any fractional batch at the end of the dataset. This is fine +for an infinitely repeating dataset being used for training, but could be a +problem if you want to train for an exact number of epochs. + +To do an exact 1-epoch of _evaluation_ you can work around this by manually +padding the length of the batches, and setting the padding entries to have zero +weight when creating your `tf.metrics`. + +## Datasets + +Efficient use of the `tf.data.Dataset` API is critical when using a Cloud +TPU, as it is impossible to use the Cloud TPU's unless you can feed it data +quickly enough. See @{$datasets_performance} for details on dataset performance. + +For all but the simplest experimentation (using +@{tf.data.Dataset.from_tensor_slices} or other in-graph data) you will need to +store all data files read by the `TPUEstimator`'s `Dataset` in Google Cloud +Storage Buckets. + + + +For most use-cases, we recommend converting your data into `TFRecord` +format and using a @{tf.data.TFRecordDataset} to read it. This, however, is not +a hard requirement and you can use other dataset readers +(`FixedLengthRecordDataset` or `TextLineDataset`) if you prefer. + +Small datasets can be loaded entirely into memory using +@{tf.data.Dataset.cache}. + +Regardless of the data format used, it is strongly recommended that you +@{$performance_guide#use_large_files$use large files}, on the order of +100MB. This is especially important in this networked setting as the overhead +of opening a file is significantly higher. + +It is also important, regardless of the type of reader used, to enable buffering +using the `buffer_size` argument to the constructor. This argument is specified +in bytes. A minimum of a few MB (`buffer_size=8*1024*1024`) is recommended so +that data is available when needed. + +The TPU-demos repo includes +[a script](https://github.com/tensorflow/tpu-demos/blob/master/cloud_tpu/datasets/imagenet_to_gcs.py) +for downloading the imagenet dataset and converting it to an appropriate format. +This together with the imagenet +[models](https://github.com/tensorflow/tpu-demos/tree/master/cloud_tpu/models) +included in the repo demonstrate all of these best-practices. + + +## What Next + +For details on how to actually set up and run a Cloud TPU see: + + * [Google Cloud TPU Documentation](https://cloud.google.com/tpu/docs/) + +This document is by no means exhaustive. The best source of more detail on how +to make a Cloud TPU compatible model are the example models published in: + + * The [TPU Demos Repository.](https://github.com/tensorflow/tpu-demos/) + +For more information about tuning TensorFlow code for performance see: + + * The @{$performance$Performance Section.} + diff --git a/tensorflow/docs_src/api_guides/python/regression_examples.md b/tensorflow/docs_src/api_guides/python/regression_examples.md index dae50a8f032bae9421bc01d1ac4043fdaae30080..7de2be05521d9293e33664cdbbd7bf16b9ad7c52 100644 --- a/tensorflow/docs_src/api_guides/python/regression_examples.md +++ b/tensorflow/docs_src/api_guides/python/regression_examples.md @@ -38,7 +38,7 @@ The preceding examples rely on the following data set utility: Utility Description - imports85.py + imports85.py This program provides utility functions that load the imports85 data set into formats that other TensorFlow programs (for example, linear_regression.py and diff --git a/tensorflow/docs_src/community/welcome.md b/tensorflow/docs_src/community/welcome.md index d2d3f9edaed9fc3c921a98c95ae24ce168e00216..9f6fe91b1490ef4ffe43acc877ecb83cc9121118 100644 --- a/tensorflow/docs_src/community/welcome.md +++ b/tensorflow/docs_src/community/welcome.md @@ -65,5 +65,5 @@ please read the following list carefully: on GitHub. For example, use the issue tracker to request a new operation in TensorFlow. * To report vulnerabilities, please follow our - [vulnerability disclosure guidelines](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md). + [vulnerability disclosure guidelines](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/SECURITY.md). diff --git a/tensorflow/docs_src/get_started/checkpoints.md b/tensorflow/docs_src/get_started/checkpoints.md index 680e1c0d3f58166a4f6b352816914f5220d84996..dfa2110e691167f54e6ea8b7a832f0a88d0ec41a 100644 --- a/tensorflow/docs_src/get_started/checkpoints.md +++ b/tensorflow/docs_src/get_started/checkpoints.md @@ -16,7 +16,7 @@ This document focuses on checkpoints. For details on SavedModel, see the ## Sample code This document relies on the same -[https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py](Iris classification example) detailed in @{$premade_estimators$Getting Started with TensorFlow}. +[Iris classification example](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py) detailed in @{$premade_estimators$Getting Started with TensorFlow}. To download and access the example, invoke the following two commands: ```shell diff --git a/tensorflow/docs_src/get_started/custom_estimators.md b/tensorflow/docs_src/get_started/custom_estimators.md index 79c4ee75d01c745d9e492c5db9df11a93eca0477..42a246678a054d637fea5a82a03ecb84ff412bd9 100644 --- a/tensorflow/docs_src/get_started/custom_estimators.md +++ b/tensorflow/docs_src/get_started/custom_estimators.md @@ -161,7 +161,7 @@ classifier = tf.estimator.Estimator( To implement a typical model function, you must do the following: -* (Define the model)[#define_the_model]. +* [Define the model](#define_the_model). * Specify additional calculations for each of the [three different modes](#modes): * [Predict](#predict) diff --git a/tensorflow/docs_src/get_started/get_started_for_beginners.md b/tensorflow/docs_src/get_started/get_started_for_beginners.md index ea1c2fb3f473b9e39567c7607d3b3ad10d2de6b5..9bca7540a73ea4354096de1b999ab708be26925c 100644 --- a/tensorflow/docs_src/get_started/get_started_for_beginners.md +++ b/tensorflow/docs_src/get_started/get_started_for_beginners.md @@ -357,7 +357,7 @@ my_feature_columns = [ ### Select the type of model -We need the select the kind of model that will be trained. +We need to select the kind of model that will be trained. Lots of model types exist; picking the ideal type takes experience. We've selected a neural network to solve the Iris problem. [**Neural networks**](https://developers.google.com/machine-learning/glossary/#neural_network) @@ -655,7 +655,9 @@ calls as follows: ```python predictions = classifier.predict( - input_fn=lambda:eval_input_fn(predict_x, batch_size=args.batch_size)) + input_fn=lambda:eval_input_fn(predict_x, + labels=None, + batch_size=args.batch_size)) ``` As with the `evaluate` method, our `predict` method also gathers examples @@ -700,7 +702,7 @@ for pred_dict, expec in zip(predictions, expected): class_id = pred_dict['class_ids'][0] probability = pred_dict['probabilities'][class_id] - print(template.format(SPECIES[class_id], 100 * probability, expec)) + print(template.format(iris_data.SPECIES[class_id], 100 * probability, expec)) ``` Running the program yields the following output: diff --git a/tensorflow/docs_src/get_started/premade_estimators.md b/tensorflow/docs_src/get_started/premade_estimators.md index 4ef212a5b55a905325e3889efe2884c2c4ff113d..4f01f997c33c211e8cff81b6b268bb320aa794df 100644 --- a/tensorflow/docs_src/get_started/premade_estimators.md +++ b/tensorflow/docs_src/get_started/premade_estimators.md @@ -2,37 +2,39 @@ # Getting Started with TensorFlow This document introduces the TensorFlow programming environment and shows you -how to write the Iris classification problem in TensorFlow. +how to solve the Iris classification problem in TensorFlow. -Prior to reading this document, do the following: +## Prerequisites + +Prior to using the sample code in this document, you'll need to do the +following: * @{$install$Install TensorFlow}. * If you installed TensorFlow with virtualenv or Anaconda, activate your TensorFlow environment. -* To keep the data import simple, our Iris example uses Pandas. You can - install Pandas with: +* Install or upgrade pandas by issuing the following command: - `pip install pandas` + pip install pandas ## Getting the sample code -Take the following steps to get the sample code for this program: +Take the following steps to get the sample code we'll be going through: -1. Clone the TensorFlow Models repository from github by entering the following +1. Clone the TensorFlow Models repository from GitHub by entering the following command: - `git clone https://github.com/tensorflow/models` + git clone https://github.com/tensorflow/models 1. Change directory within that branch to the location containing the examples used in this document: - `cd models/samples/core/get_started/` + cd models/samples/core/get_started/ The program described in this document is [`premade_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py). This program uses [`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py) -To fetch its training data. +to fetch its training data. ### Running the program @@ -45,7 +47,7 @@ python premade_estimator.py The program should output training logs followed by some predictions against the test set. For example, the first line in the following output shows that the model thinks there is a 99.6% chance that the first example in the test -set is a Setosa. Since the test set `expected "Setosa"`, this appears to be +set is a Setosa. Since the test set expected Setosa, this appears to be a good prediction. ``` None @@ -61,9 +63,9 @@ If the program generates errors instead of answers, ask yourself the following questions: * Did you install TensorFlow properly? -* Are you using the correct version of tensorflow? +* Are you using the correct version of TensorFlow? * Did you activate the environment you installed TensorFlow in? (This is - only relevant in certain installation environments.) + only relevant in certain installation mechanisms.) ## The programming stack @@ -74,18 +76,15 @@ provides a programming stack consisting of multiple API layers:
-
-The TensorFlow Programming Environment -
We strongly recommend writing TensorFlow programs with the following APIs: -* @{tf.estimator$Estimators}, which represent a complete model. +* @{$programmers_guide/estimators$Estimators}, which represent a complete model. The Estimator API provides methods to train the model, to judge the model's accuracy, and to generate predictions. * @{$get_started/datasets_quickstart$Datasets}, which build a data input pipeline. The Dataset API has methods to load and manipulate data, and feed - it into your model. The Datasets API meshes well with the Estimators API. + it into your model. The Dataset API meshes well with the Estimators API. ## Classifying irises: an overview @@ -120,7 +119,7 @@ individual Iris flowers: * petal length * petal width -Our model will represent these features as float32 numerical data. +Our model will represent these features as `float32` numerical data. The label identifies the Iris species, which must be one of the following: @@ -154,9 +153,6 @@ The following figure illustrates the features, hidden layers, and predictions alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs" src="../images/custom_estimators/full_network.png">
-
-The Model. -
### Inference @@ -174,12 +170,12 @@ example is an Iris Versicolor. ## Overview of programming with Estimators -An Estimator is TensorFlow's high level representation of a complete model. It +An Estimator is TensorFlow's high-level representation of a complete model. It handles the details of initialization, logging, saving and restoring, and many other features so you can concentrate on your model. For more details see @{$programmers_guide/estimators}. -An "Estimator" is any class derived from @{tf.estimator.Estimator}. TensorFlow +An Estimator is any class derived from @{tf.estimator.Estimator}. TensorFlow provides a collection of [pre-made Estimators](https://developers.google.com/machine-learning/glossary/#pre-made_Estimator) (for example, `LinearRegressor`) to implement common ML algorithms. Beyond @@ -199,7 +195,7 @@ following tasks: * Call one or more methods on the Estimator object, passing the appropriate input function as the source of the data. -Let's see how those tasks are implemented in Iris. +Let's see how those tasks are implemented for Iris classification. ## Create input functions @@ -209,17 +205,30 @@ evaluating, and prediction. An **input function** is a function that returns a @{tf.data.Dataset} object which outputs the following two-element tuple: -* "features" - A Python dictionary in which: +* [`features`](https://developers.google.com/machine-learning/glossary/#feature) - A Python dictionary in which: * Each key is the name of a feature. * Each value is an array containing all of that feature's values. -* "label" - An array containing the values of the +* `label` - An array containing the values of the [label](https://developers.google.com/machine-learning/glossary/#label) for every example. -Your input function may generate the "features" dictionary and "label" list any -way you like. However, we recommend using TensorFlow's @{tf.data.Dataset} API, -which can deftly parse all sorts of data. At a high-level, -the @{tf.data.Dataset} API consists of the following classes: +Just to demonstrate the format of the input function, here's a simple +implementation: + +```python +def input_evaluation_set(): + features = {'SepalLength': np.array([6.4, 5.0]), + 'SepalWidth': np.array([2.8, 2.3]), + 'PetalLength': np.array([5.6, 3.3]), + 'PetalWidth': np.array([2.2, 1.0])} + labels = np.array([2, 1]) + return features, labels +``` + +Your input function may generate the `features` dictionary and `label` list any +way you like. However, we recommend using TensorFlow's Dataset API, which can +parse all sorts of data. At a high level, the Dataset API consists of the +following classes:
+Where the individual members are: -Where: - -* Dataset: Base class containing methods to create and transform datasets. Also - allows you to initialize a dataset from data in memory, or from a Python - generator. -* TextLineDataset: Reads lines from text files. -* TFRecordDataset: Reads records from TFRecord files. -* FixedLengthRecordDataset: Reads fixed size records from binary files. -* Iterator: Provides a way to access one data set element at a time. +* `Dataset` - Base class containing methods to create and transform + datasets. Also allows you to initialize a dataset from data in memory, or from + a Python generator. +* `TextLineDataset` - Reads lines from text files. +* `TFRecordDataset` - Reads records from TFRecord files. +* `FixedLengthRecordDataset` - Reads fixed size records from binary files. +* `Iterator` - Provides a way to access one data set element at a time. The Dataset API can handle a lot of common cases for you. For example, using the Dataset API, you can easily read in records from a large collection of files in parallel and join them into a single stream. -To keep things simple in this example we are going to load the data with pandas, -and build our input pipeline from this in-memory data. +To keep things simple in this example we are going to load the data with +[pandas](https://pandas.pydata.org/), and build our input pipeline from this +in-memory data. Here is the input function used for training in this program, which is available in [`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py): @@ -258,9 +267,9 @@ def train_input_fn(features, labels, batch_size): return dataset.shuffle(1000).repeat().batch(batch_size) ``` -## Define the Feature Columns +## Define the feature columns -A [**Feature Column**](https://developers.google.com/machine-learning/glossary/#feature_columns) +A [**feature column**](https://developers.google.com/machine-learning/glossary/#feature_columns) is an object describing how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use. @@ -270,7 +279,7 @@ to the model. For Iris, the 4 raw features are numeric values, so we'll build a list of feature columns to tell the Estimator model to represent each of the four features as 32-bit floating-point values. Therefore, the code to create the -Feature Column is simply: +feature column is: ```python # Feature columns describe how to use the input. @@ -279,29 +288,29 @@ for key in train_x.keys(): my_feature_columns.append(tf.feature_column.numeric_column(key=key)) ``` -Feature Columns can be far more sophisticated than those we're showing here. -We detail feature columns @{$get_started/feature_columns$later on} in -getting started. +Feature columns can be far more sophisticated than those we're showing here. We +detail feature columns @{$get_started/feature_columns$later on} in our Getting +Started guide. Now that we have the description of how we want the model to represent the raw features, we can build the estimator. -## Instantiate an Estimator +## Instantiate an estimator The Iris problem is a classic classification problem. Fortunately, TensorFlow provides several pre-made classifier Estimators, including: -* @{tf.estimator.DNNClassifier}—for deep models that perform multi-class +* @{tf.estimator.DNNClassifier} for deep models that perform multi-class classification. -* @{tf.estimator.DNNLinearCombinedClassifier}—for wide-n-deep models. -* @{tf.estimator.LinearClassifier}— for classifiers based on linear models. +* @{tf.estimator.DNNLinearCombinedClassifier} for wide & deep models. +* @{tf.estimator.LinearClassifier} for classifiers based on linear models. For the Iris problem, `tf.estimator.DNNClassifier` seems like the best choice. Here's how we instantiated this Estimator: ```python -# Build 2 hidden layer DNN with 10, 10 units respectively. +# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer. classifier = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, # Two hidden layers of 10 nodes each. diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md index ba1a4118aece1f42822f7cd084feed50c5cf6ebb..a783205b4a2d24182de6496e0173635990120185 100644 --- a/tensorflow/docs_src/install/install_c.md +++ b/tensorflow/docs_src/install/install_c.md @@ -38,7 +38,7 @@ enable TensorFlow for C: OS="linux" # Change to "darwin" for macOS TARGET_DIRECTORY="/usr/local" curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.5.0-rc1.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.6.0-rc0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md index 87cc647317a11fab0d9d0219dd5764af3dcb2ecc..5249e04615b506186a12807bb71ec4079db8156c 100644 --- a/tensorflow/docs_src/install/install_go.md +++ b/tensorflow/docs_src/install/install_go.md @@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go: TF_TYPE="cpu" # Change to "gpu" for GPU support TARGET_DIRECTORY='/usr/local' curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.5.0-rc1.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.6.0-rc0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 37e109a6e4bdee97ad02bc7aceb2c0c24e1ec7ec..0c6c773e62483b2272cf3b80da0932b4b800bb71 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs: org.tensorflow tensorflow - 1.5.0-rc1 + 1.6.0-rc0 ``` @@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow: org.tensorflow tensorflow - 1.5.0-rc1 + 1.6.0-rc0 @@ -123,12 +123,12 @@ instead: org.tensorflow libtensorflow - 1.5.0-rc1 + 1.6.0-rc0 org.tensorflow libtensorflow_jni_gpu - 1.5.0-rc1 + 1.6.0-rc0 ``` @@ -147,7 +147,7 @@ refer to the simpler instructions above instead. Take the following steps to install TensorFlow for Java on Linux or macOS: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.5.0-rc1.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar), which is the TensorFlow Java Archive (JAR). 2. Decide whether you will run TensorFlow for Java on CPU(s) only or with @@ -166,7 +166,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.5.0-rc1.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.6.0-rc0.tar.gz" | tar -xz -C ./jni ### Install on Windows @@ -174,10 +174,10 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: Take the following steps to install TensorFlow for Java on Windows: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.5.0-rc1.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar), which is the TensorFlow Java Archive (JAR). 2. Download the following Java Native Interface (JNI) file appropriate for - [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.5.0-rc1.zip). + [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0-rc0.zip). 3. Extract this .zip file. @@ -225,7 +225,7 @@ must be part of your `classpath`. For example, you can include the downloaded `.jar` in your `classpath` by using the `-cp` compilation flag as follows: -
javac -cp libtensorflow-1.5.0-rc1.jar HelloTF.java
+
javac -cp libtensorflow-1.6.0-rc0.jar HelloTF.java
### Running @@ -239,11 +239,11 @@ two files are available to the JVM: For example, the following command line executes the `HelloTF` program on Linux and macOS X: -
java -cp libtensorflow-1.5.0-rc1.jar:. -Djava.library.path=./jni HelloTF
+
java -cp libtensorflow-1.6.0-rc0.jar:. -Djava.library.path=./jni HelloTF
And the following command line executes the `HelloTF` program on Windows: -
java -cp libtensorflow-1.5.0-rc1.jar;. -Djava.library.path=jni HelloTF
+
java -cp libtensorflow-1.6.0-rc0.jar;. -Djava.library.path=jni HelloTF
If the program prints Hello from version, you've successfully installed TensorFlow for Java and are ready to use the API. If the program diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index 7289224572009d6e7b77498b8c560381717ea73a..105b225177315db07b1117c3ece4b77dd2b60cb2 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -188,7 +188,7 @@ Take the following steps to install TensorFlow with Virtualenv: Virtualenv environment:
(tensorflow)$ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl If you encounter installation problems, see [Common Installation Problems](#common_installation_problems). @@ -293,7 +293,7 @@ take the following steps:
      $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
      
If this step fails, see @@ -480,7 +480,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
      (tensorflow)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl @@ -648,14 +648,14 @@ This section documents the relevant values for Linux installations. CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp27-none-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp27-none-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -667,14 +667,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -686,14 +686,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
 
@@ -705,14 +705,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
 
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 555a6837d8beb153bd2b55b089be99b701c4f30c..a6ea548cfbdb3070c19b5c19ebc903ca76a4656a 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -115,7 +115,7 @@ Take the following steps to install TensorFlow with Virtualenv: TensorFlow in the active Virtualenv is as follows:
 $ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py3-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl If you encounter installation problems, see [Common Installation Problems](#common-installation-problems). @@ -238,7 +238,7 @@ take the following steps: issue the following command:
 $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py3-none-any.whl 
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl If the preceding command fails, see [installation problems](#common-installation-problems). @@ -347,7 +347,7 @@ Take the following steps to install TensorFlow in an Anaconda environment: TensorFlow for Python 2.7:
 (targetDirectory)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl @@ -520,7 +520,7 @@ This section documents the relevant values for Mac OS installations.
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl
 
@@ -528,5 +528,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-a
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl
 
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index 0d99f9a47d46df5033c6cb6b02bc412a580ddeb7..0dbb15188e5074d48c8dc7fcdcdd33daf123c6d7 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -221,7 +221,7 @@ problem, do either of the following: * Download Xcode 7.2 and select it as your default by issuing the following command: -
 $ sudo xcode-select -s /Application/Xcode-7.2/Xcode.app
+
 $ sudo xcode-select -s /Applications/Xcode-7.2/Xcode.app
**NOTE:** Your system must fulfill the NVIDIA software requirements described in one of the following documents: @@ -272,8 +272,6 @@ Found possible Python library paths: Please input the desired Python library path to use. Default is [/usr/lib/python2.7/dist-packages] Using python library path: /usr/local/lib/python2.7/dist-packages -Do you wish to build TensorFlow with MKL support? [y/N] -No MKL support will be enabled for TensorFlow Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]: Do you wish to use jemalloc as the malloc implementation? [Y/n] jemalloc enabled @@ -361,10 +359,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl` file depends on your platform. For example, the following command will install the pip package -for TensorFlow 1.5.0rc1 on Linux: +for TensorFlow 1.6.0rc0 on Linux:
-$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.5.0rc1-py2-none-any.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0rc0-py2-none-any.whl
 
## Validate your installation @@ -462,9 +460,10 @@ Stack Overflow and specify the `tensorflow` tag. **Linux** - - - + + + + @@ -480,7 +479,8 @@ Stack Overflow and specify the `tensorflow` tag. **Mac**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.5.0-rc1CPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.0N/AN/A
tensorflow_gpu-1.5.0-rc1GPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.079
tensorflow-1.6.0rc0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.0N/AN/A
tensorflow_gpu-1.6.0rc0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
tensorflow-1.5.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.0N/AN/A
tensorflow_gpu-1.5.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.079
tensorflow-1.4.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.5.4N/AN/A
tensorflow_gpu-1.4.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.5.468
tensorflow-1.3.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.4.5N/AN/A
- + + @@ -493,8 +493,10 @@ Stack Overflow and specify the `tensorflow` tag. **Windows**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.5.0-rc1CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.6.0rc0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.5.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.4.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.5.4N/AN/A
tensorflow-1.3.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.4.5N/AN/A
tensorflow-1.2.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.4.5N/AN/A
- - + + + + diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index 10e7ad7ada533c8da5e5b871b38809b90604685e..cd47fc2803bc1429d28bd0ae4c2ad68e632a6f03 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -498,7 +498,7 @@ For TensorFlow source versions after 1.3.0: ```bash ./configure # Pick the desired options -bazel build --config=mkl -c opt //tensorflow/tools/pip_package:build_pip_package +bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package ``` diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 1e9b8b35db65ef19a4bcb607b98af1e1de4e6d5b..f865c30aa8bd1639a72a0f4641b883a8890f9c13 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -252,7 +252,7 @@ Clamps an operand to within the range between a minimum and maximum value. Given an operand and minimum and maximum values, returns the operand if it is in the range between the minimum and maximum, else returns the minimum value if the operand is below this range or the maximum value if the operand is above this -range. That is, `clamp(a, x, b) = max(min(a, x), b)`. +range. That is, `clamp(a, x, b) = min(max(a, x), b)`. All three arrays must be the same shape. Alternately, as a restricted form of [broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`. diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index 9eaee2702829cbfd96cd56e832003724eba5bb1b..c1a90dee0a6e5e6bc9b51cd232718179f1511f61 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -214,7 +214,7 @@ navigate between these screens by clicking the `<--` and ### Other Features of the tfdbg CLI In addition to the commands listed above, the tfdbg CLI provides the following -addditional features: +additional features: * To navigate through previous tfdbg commands, type in a few characters followed by the Up or Down arrow keys. tfdbg will show you the history of diff --git a/tensorflow/docs_src/programmers_guide/graphs.md b/tensorflow/docs_src/programmers_guide/graphs.md index 2b4896c381052b5a3fb97385a18dbff82c2c0d89..9049a5a9f3d44e255188c6c41cdb12a619464379 100644 --- a/tensorflow/docs_src/programmers_guide/graphs.md +++ b/tensorflow/docs_src/programmers_guide/graphs.md @@ -125,14 +125,14 @@ an operation: @{tf.Tensor} accepts an optional `name` argument. For example, `tf.constant(42.0, name="answer")` creates a new @{tf.Operation} named `"answer"` and returns a @{tf.Tensor} named `"answer:0"`. If the default graph - already contained an operation named `"answer"`, the TensorFlow would append + already contains an operation named `"answer"`, then TensorFlow would append `"_1"`, `"_2"`, and so on to the name, in order to make it unique. * The @{tf.name_scope} function makes it possible to add a **name scope** prefix to all operations created in a particular context. The current name scope prefix is a `"/"`-delimited list of the names of all active @{tf.name_scope} context managers. If a name scope has already been used in the current - context, TensorFlow appens `"_1"`, `"_2"`, and so on. For example: + context, TensorFlow appends `"_1"`, `"_2"`, and so on. For example: ```python c_0 = tf.constant(0, name="c") # => operation named "c" diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md index d45e666ce7b440bae20ba32d894526372af7e17b..7a5e90081d9145ca934929f0af11f2a40cb2dcae 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/programmers_guide/index.md @@ -13,7 +13,7 @@ works. The units are as follows: ## Low Level APIs * @{$programmers_guide/low_level_intro}, which introduces the - basics of how you can to use TensorFlow outside of the high Level APIs. + basics of how you can use TensorFlow outside of the high Level APIs. * @{$programmers_guide/tensors}, which explains how to create, manipulate, and access Tensors--the fundamental object in TensorFlow. * @{$programmers_guide/variables}, which details how diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md index 9f50be5b31cd8b61b81426f50aa9ef9beb3138f2..f27a658342b8d33407e1c6ed5799a10c2305a74c 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/programmers_guide/saved_model.md @@ -285,7 +285,7 @@ with tf.Session(graph=tf.Graph()) as sess: ``` -### Loading a Savedmodel in C++ +### Loading a SavedModel in C++ The C++ version of the SavedModel [loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h) @@ -303,6 +303,30 @@ LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain}, &bundle); ``` +### Loading and Serving a SavedModel in TensorFlow Serving + +You can easily load and serve a SavedModel with the TensorFlow Serving Model +Server binary. See [instructions](https://www.tensorflow.org/serving/setup#installing_using_apt-get) +on how to install the server, or build it if you wish. + +Once you have the Model Server, run it with: +``` +tensorflow_model_server --port=port-numbers --model_name=your-model-name --model_base_path=your_model_base_path +``` +Set the port and model_name flags to values of your choosing. The +model_base_path flag expects to be to a base directory, with each version of +your model residing in a numerically named subdirectory. If you only have a +single version of your model, simply place it in a subdirectory like so: +* Place the model in /tmp/model/0001 +* Set model_base_path to /tmp/model + +Store different versions of your model in numerically named subdirectories of a +common base directory. For example, suppose the base directory is `/tmp/model`. +If you have only one version of your model, store it in `/tmp/model/0001`. If +you have two versions of your model, store the second version in +`/tmp/model/0002`, and so on. Set the `--model-base_path` flag to the base +directory (`/tmp/model`, in this example). TensorFlow Model Server will serve +the model in the highest numbered subdirectory of that base directory. ### Standard constants diff --git a/tensorflow/examples/android/build.gradle b/tensorflow/examples/android/build.gradle index f7bdf8b816a8191770bc1ad59b890041b8e39912..0767726aa9a248fb073fbd4114f47d1b4ed6901b 100644 --- a/tensorflow/examples/android/build.gradle +++ b/tensorflow/examples/android/build.gradle @@ -56,10 +56,12 @@ def nativeOutDir = 'libs/' + cpuType def nativeBuildRule = 'buildNativeBazel' def demoLibPath = '../../../bazel-bin/tensorflow/examples/android/libtensorflow_demo.so' def inferenceLibPath = '../../../bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so' + +// Override for Makefile builds. if (nativeBuildSystem == 'makefile') { nativeBuildRule = 'buildNativeMake' - demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/libtensorflow_demo.so' - inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/libtensorflow_inference.so' + demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_demo.so' + inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_inference.so' } // If building with Bazel, this is the location of the bazel binary. @@ -154,7 +156,8 @@ task buildNativeMake(type: Exec) { '-s', \ 'tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in', \ '-t', \ - 'libtensorflow_inference.so libtensorflow_demo.so' \ + 'libtensorflow_inference.so libtensorflow_demo.so all' \ + , '-a', cpuType \ //, '-T' // Uncomment to skip protobuf and speed up subsequent builds. } diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java index 2fe2ba539edc84e80baf36b6d1ac1e192bc92163..af6af2bc8f508a70aa7e44a7236f0e7ea5e3d71c 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java @@ -199,7 +199,7 @@ public class MultiBoxTracker { final int w, final int h, final int rowStride, - final int sensorOrienation, + final int sensorOrientation, final byte[] frame, final long timestamp) { if (objectTracker == null && !initialized) { @@ -209,7 +209,7 @@ public class MultiBoxTracker { objectTracker = ObjectTracker.getInstance(w, h, rowStride, true); frameWidth = w; frameHeight = h; - this.sensorOrientation = sensorOrienation; + this.sensorOrientation = sensorOrientation; initialized = true; if (objectTracker == null) { diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index ec22684eaf63700c608c6ce45f22941555246b99..58c5f87884e5a091300f128403d00fb90bad59fe 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -344,8 +344,8 @@ def maybe_download_and_extract(data_url): filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress) print() statinfo = os.stat(filepath) - tf.logging.info('Successfully downloaded', filename, statinfo.st_size, - 'bytes.') + tf.logging.info('Successfully downloaded %s %d bytes.', + filename, statinfo.st_size) print('Extracting file from ', filepath) tarfile.open(filepath, 'r:gz').extractall(dest_directory) else: diff --git a/tensorflow/examples/learn/text_classification.py b/tensorflow/examples/learn/text_classification.py index eb117c39a122f4f6c108dd18f8f8035edf05eaa1..e4e61862b02f9827f42c8d0052a7be8a57502dd8 100644 --- a/tensorflow/examples/learn/text_classification.py +++ b/tensorflow/examples/learn/text_classification.py @@ -34,8 +34,7 @@ MAX_LABEL = 15 WORDS_FEATURE = 'words' # Name of the input words feature. -def estimator_spec_for_softmax_classification( - logits, labels, mode): +def estimator_spec_for_softmax_classification(logits, labels, mode): """Returns EstimatorSpec instance for softmax classification.""" predicted_classes = tf.argmax(logits, 1) if mode == tf.estimator.ModeKeys.PREDICT: @@ -53,8 +52,8 @@ def estimator_spec_for_softmax_classification( return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) eval_metric_ops = { - 'accuracy': tf.metrics.accuracy( - labels=labels, predictions=predicted_classes) + 'accuracy': + tf.metrics.accuracy(labels=labels, predictions=predicted_classes) } return tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) @@ -67,8 +66,7 @@ def bag_of_words_model(features, labels, mode): bow_embedding_column = tf.feature_column.embedding_column( bow_column, dimension=EMBEDDING_SIZE) bow = tf.feature_column.input_layer( - features, - feature_columns=[bow_embedding_column]) + features, feature_columns=[bow_embedding_column]) logits = tf.layers.dense(bow, MAX_LABEL, activation=None) return estimator_spec_for_softmax_classification( @@ -110,9 +108,9 @@ def main(unused_argv): # Prepare training and testing data dbpedia = tf.contrib.learn.datasets.load_dataset( 'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data) - x_train = pandas.Series(dbpedia.train.data[:,1]) + x_train = pandas.Series(dbpedia.train.data[:, 1]) y_train = pandas.Series(dbpedia.train.target) - x_test = pandas.Series(dbpedia.test.data[:,1]) + x_test = pandas.Series(dbpedia.test.data[:, 1]) y_test = pandas.Series(dbpedia.test.target) # Process vocabulary @@ -152,10 +150,7 @@ def main(unused_argv): # Predict. test_input_fn = tf.estimator.inputs.numpy_input_fn( - x={WORDS_FEATURE: x_test}, - y=y_test, - num_epochs=1, - shuffle=False) + x={WORDS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) predictions = classifier.predict(input_fn=test_input_fn) y_predicted = np.array(list(p['class'] for p in predictions)) y_predicted = y_predicted.reshape(np.array(y_test).shape) diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py index d055d157454d4cb351e8db59eec484f212893fe5..f6906b0f79b86910b5354bea420d00f62ff0caf8 100644 --- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py @@ -270,12 +270,6 @@ with tf.Session(graph=graph) as session: run_metadata=run_metadata) average_loss += loss_val - # Add returned summaries to writer in each step. - writer.add_summary(summary, step) - # Add metadata to visualize the graph for the last run. - if step == (num_steps - 1): - writer.add_run_metadata(run_metadata, 'step%d' % step) - # Add returned summaries to writer in each step. writer.add_summary(summary, step) # Add metadata to visualize the graph for the last run. diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index fc087d9d995dfe031e61fd0fa15d649c2ee35cc9..08943a527cbdc072b12b066240c213be45ffd54c 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -173,7 +173,11 @@ type OpSpec struct { // operation. Attrs map[string]interface{} - // Other possible fields: Device, ColocateWith, ControlInputs. + // Operations that must be executed before executing the operation + // being added. + ControlDependencies []*Operation + + // Other possible fields: Device, ColocateWith. } // AddOperation adds an operation to g. @@ -204,6 +208,9 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { } } } + for _, in := range args.ControlDependencies { + C.TF_AddControlInput(cdesc, in.c) + } status := newStatus() for name, value := range args.Attrs { if err := setAttr(cdesc, status, name, value); err != nil { diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go index a9ec79463a00022bf85bf00032df9004648525ae..13de4294dc2ebdfff9bb68d277c09239d0bc8593 100644 --- a/tensorflow/go/op/scope.go +++ b/tensorflow/go/op/scope.go @@ -33,10 +33,11 @@ import ( // A Scope object and all its derivates (e.g., obtained from Scope.SubScope) // are not safe for concurrent use by multiple goroutines. type Scope struct { - graph *tf.Graph - namemap map[string]int - namespace string - err *scopeErr + graph *tf.Graph + namemap map[string]int + namespace string + controlDependencies []*tf.Operation + err *scopeErr } // scopeErr is used to share errors between all derivatives of a root scope. @@ -80,6 +81,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation { if s.namespace != "" { args.Name = s.namespace + "/" + args.Name } + args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...) op, err := s.graph.AddOperation(args) if err != nil { s.UpdateErr(args.Type, err) @@ -103,6 +105,28 @@ func (s *Scope) SubScope(namespace string) *Scope { } } +// WithControlDependencies returns a new Scope which will cause all operations +// added to the graph to execute only after all the provided operations have +// executed first (in addition to any other control dependencies in s). +func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { + // Force a copy of the control dependencies into a new underlying array on + // every call. We cannot alias the same underlying array as `ops`, otherwise + // the user could modify that array after calling s.WithControlDependencies, + // which would be confusing. We cannot alias the same underlying array as the + // original `s.controlDependencies`, since Scopes form a logical tree, and + // other calls to s.WithControlDependencies could stomp on each other. + deps := make([]*tf.Operation, 0, len(s.controlDependencies)+len(ops)) + deps = append(deps, s.controlDependencies...) + deps = append(deps, ops...) + return &Scope{ + graph: s.graph, + namemap: s.namemap, + namespace: s.namespace, + controlDependencies: deps, + err: s.err, + } +} + // Err returns the error, if any, encountered during the construction // of the Graph managed by s. // diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go index 6fb5d32e503c7c9a5a48747844da15be81b1de2d..b58a61de98b0f5b04959e1eca35c6b6c4d77e42b 100644 --- a/tensorflow/go/op/scope_test.go +++ b/tensorflow/go/op/scope_test.go @@ -69,6 +69,49 @@ func TestScopeSubScopeErrors(t *testing.T) { } } +func TestControlDependencies(t *testing.T) { + var ( + s = NewScope() + zero = Const(s.SubScope("zero"), int32(0)) + one = Const(s.SubScope("one"), int32(1)) + variable = VarHandleOp(s, tf.Int32, tf.ScalarShape()) + init = AssignVariableOp(s, variable, zero) + update = AssignAddVariableOp(s, variable, one) + readDeps = []*tf.Operation{update} + ) + // We intend for `read` to have a control dependency on `update`. + s = s.WithControlDependencies(readDeps...) + // Ensure that Scope.WithControlDependencies makes a copy of the underlying + // array, rather than just holding a slice reference to the same user-supplied + // underlying array. If the copy is correctly performed, overwriting + // readDeps[0] should have no effect on control dependencies for `read`. + readDeps[0] = init + read := ReadVariableOp(s, variable, tf.Int32) + + graph, err := s.Finalize() + if err != nil { + t.Fatal(err) + } + sess, err := tf.NewSession(graph, nil) + if err != nil { + t.Fatal(err) + } + if _, err = sess.Run(nil, nil, []*tf.Operation{init}); err != nil { + t.Fatal(err) + } + // Without the control dependency, the read operation may not see the + // update. + for i := int32(0); i < 10; i++ { + out, err := sess.Run(nil, []tf.Output{read}, nil) + if err != nil { + t.Fatal(err) + } + if got, want := out[0].Value().(int32), i+1; got != want { + t.Errorf("Got %d, want %d", got, want) + } + } +} + func TestScopeFinalize(t *testing.T) { var ( root = NewScope() diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 5b19c90238ef3bb1361a5e2476e94dd06e76d128..cb47651d7b3199cc804b2a3e89aaf2cead7b75c1 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -8729,31 +8729,6 @@ func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Out return op.Output(0) } -// Compute the pairwise cross product. -// -// `a` and `b` must be the same shape; they can either be simple 3-element vectors, -// or any shape where the innermost dimension is 3. In the latter case, each pair -// of corresponding 3-element vectors is cross-multiplied independently. -// -// Arguments: -// a: A tensor containing 3-element vectors. -// b: Another tensor, of same type and shape as `a`. -// -// Returns Pairwise cross product of the vectors in `a` and `b`. -func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cross", - Input: []tf.Input{ - a, b, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Transforms a vector of brain.Example protos (as strings) into typed tensors. // // Arguments: @@ -21290,6 +21265,31 @@ func StatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output return op.Output(0) } +// Compute the pairwise cross product. +// +// `a` and `b` must be the same shape; they can either be simple 3-element vectors, +// or any shape where the innermost dimension is 3. In the latter case, each pair +// of corresponding 3-element vectors is cross-multiplied independently. +// +// Arguments: +// a: A tensor containing 3-element vectors. +// b: Another tensor, of same type and shape as `a`. +// +// Returns Pairwise cross product of the vectors in `a` and `b`. +func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cross", + Input: []tf.Input{ + a, b, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Performs a padding as a preprocess during a convolution. // // Similar to FusedResizeAndPadConv2d, this op allows for an optimized diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 6285ee0483d9171d6cdb9b4dbf2675bafb953038..a9ce5372aeb32b6957359fdcaa9da01c732c9f9f 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.5.0 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index b0e5c44fecc9bf3a95ac3d4e36d9f98d74d3b2bb..fe34ca83ff30373fa4f3c4f345323bad40a8754e 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.5.0 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 02c5dca13f4d292718afca7e99bac82710e1949f..390152808eb0c4abebe093ed5db39faf37fcafe3 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.5.0 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 949597ca7f1e7a05cf6c0e5a15cb5307b00859a1..524ec45f48bb91d09dfb5fca3cc19256d45587fb 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0org.tensorflowparentpom - 1.5.0-rc1 + 1.5.0pomhttps://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 9f0ebcf84c9c8e01662a93034a4407c6b58a6d7e..9cf3217f51f73184a02a58ead1a2735c5a44fd26 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.5.0 ../ proto diff --git a/tensorflow/java/maven/tensorflow-android/update.py b/tensorflow/java/maven/tensorflow-android/update.py index 7c250718347f5fdd65aaf8003aad75a87a19c96a..4ae666e4e5351f1bdaf79d1b5cfdb63b0f811e2b 100644 --- a/tensorflow/java/maven/tensorflow-android/update.py +++ b/tensorflow/java/maven/tensorflow-android/update.py @@ -95,7 +95,7 @@ def main(): release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow' info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version) aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version) - build_type = 'release-matrix-android' + build_type = 'release-matrix-android2' # Retrieve build information build_info = get_json(info_url) diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 88d897362ad6c8f84d93cbc9bcf3c30905b345be..d619f986a9a03ac67f5de6bbe80e686a05ce5d42 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.5.0 ../ tensorflow diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java index 499757e8cf4d6166e425d801ce20335bd8ad83e8..cf773e1686dea97f62f432be43f2c10b69fa8e24 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java +++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java @@ -88,7 +88,7 @@ final class NativeLibrary { // Deletions are in the reverse order of requests, so we need to request that the directory be // deleted first, so that it is empty when the request is fulfilled. tempPath.deleteOnExit(); - final String tempDirectory = tempPath.toString(); + final String tempDirectory = tempPath.getCanonicalPath(); if (frameworkResource != null) { extractResource(frameworkResource, frameworkLibName, tempDirectory); } else { diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c73d6c37eecea75aa64f721653b70fbdaf4855cf..f5cd7885e72f743713933db7ff754ee89bed2626 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -299,6 +299,7 @@ cc_library( ":safe_ptr", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -576,6 +577,7 @@ py_library( ":pywrap_tensorflow", ":random_seed", ":sparse_tensor", + ":tensor_spec", ":tensor_util", ":util", "//tensorflow/python/eager:context", @@ -780,6 +782,18 @@ py_library( ], ) +py_library( + name = "tensor_spec", + srcs = ["framework/tensor_spec.py"], + srcs_version = "PY2AND3", + deps = [ + ":common_shapes", + ":dtypes", + ":tensor_shape", + "//third_party/py/numpy", + ], +) + py_library( name = "tensor_util", srcs = ["framework/tensor_util.py"], @@ -1148,6 +1162,21 @@ py_test( ], ) +py_test( + name = "framework_tensor_spec_test", + size = "small", + srcs = ["framework/tensor_spec_test.py"], + main = "framework/tensor_spec_test.py", + srcs_version = "PY2AND3", + deps = [ + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":platform_test", + ":tensor_spec", + "//third_party/py/numpy", + ], +) + py_test( name = "framework_sparse_tensor_test", size = "small", @@ -4265,12 +4294,6 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) -filegroup( - name = "hidden_ops", - srcs = ["ops/hidden_ops.txt"], - visibility = ["//tensorflow:__subpackages__"], -) - cuda_py_test( name = "accumulate_n_benchmark", size = "large", diff --git a/tensorflow/python/build_defs.bzl b/tensorflow/python/build_defs.bzl index 7f29adc06fcc5922114b7cd2bde8a8df5b1e0665..b9056f86e6d0465a8521f054a459c06eb5aeb37c 100644 --- a/tensorflow/python/build_defs.bzl +++ b/tensorflow/python/build_defs.bzl @@ -22,7 +22,6 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[], bare_op_name = name[:-4] # Strip off the _gen tf_gen_op_wrapper_py(name=bare_op_name, out=out, - hidden_file="ops/hidden_ops.txt", visibility=visibility, deps=deps, require_shape_functions=require_shape_functions, diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py index 8babe35b3230e7b46c0c9484ccddae4e5e22a335..4b6a0f71ae65aa28b70dd22ce6cffa82e9bc5973 100644 --- a/tensorflow/python/client/notebook.py +++ b/tensorflow/python/client/notebook.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Notebook front-end to TensorFlow. When you run this binary, you'll see something like below, which indicates @@ -43,10 +42,8 @@ from tensorflow.python.platform import app os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2" - FLAGS = None - ORIG_ARGV = sys.argv # Main notebook process calls itself with argv[1]="kernel" to start kernel # subprocesses. @@ -73,8 +70,8 @@ def main(unused_argv): notebookapp.ip = "0.0.0.0" notebookapp.password = passwd(FLAGS.password) else: - print ("\nNo password specified; Notebook server will only be available" - " on the local machine.\n") + print("\nNo password specified; Notebook server will only be available" + " on the local machine.\n") notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir]) if notebookapp.ip == "0.0.0.0": @@ -125,8 +122,8 @@ if __name__ == "__main__": # kernel app. if IS_KERNEL: # Drop everything except --flagfile. - sys.argv = ([sys.argv[0]] + - [x for x in sys.argv[1:] if x.startswith("--flagfile")]) + sys.argv = ( + [sys.argv[0]] + [x for x in sys.argv[1:] if x.startswith("--flagfile")]) FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index e6f94396b85eb4d0ab0774a53484089f735be940..f3c4fecdc0fde0436bea76cc774edaabe1bc07dd 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import session_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export class SessionInterface(object): @@ -1441,6 +1442,7 @@ class BaseSession(SessionInterface): return handles +@tf_export('Session') class Session(BaseSession): """A class for running TensorFlow operations. @@ -1537,8 +1539,22 @@ class Session(BaseSession): def __exit__(self, exec_type, exec_value, exec_tb): if exec_type is errors.OpError: logging.error('Session closing due to OpError: %s', (exec_value,)) - self._default_session_context_manager.__exit__(exec_type, exec_value, - exec_tb) + try: + self._default_session_context_manager.__exit__(exec_type, exec_value, + exec_tb) + except RuntimeError as error: + if error == exec_value: + # NOTE(skyewm): for some reason, in Python3, + # _default_session_context_manager.__exit__ will re-raise the "not + # re-entrant" exception raised in __enter__ above (note that if we're + # here, we're in the outer session context manager, since __exit__ is + # not called when __enter__ raises an exception). We still want to + # continue cleaning up this context manager before the exception is + # further propagated, so we ignore it here (note that it'll continue + # being propagated after this method completes). + pass + else: + raise self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb) self._default_session_context_manager = None @@ -1581,6 +1597,7 @@ class Session(BaseSession): tf_session.TF_Reset(target, containers, config) +@tf_export('InteractiveSession') class InteractiveSession(BaseSession): """A TensorFlow `Session` for use in interactive contexts, such as a shell. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 768a5db88aa647609dba1c479a5aca68cd26652a..f12c0055115d41f1b90ba319ad79ec23378bebb1 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -46,6 +46,7 @@ from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. @@ -1745,8 +1746,10 @@ class SessionTest(test_util.TensorFlowTestCase): def runTestBuildGraphError(self, sess): # Ensure that errors from building the graph get propagated. data = array_ops.placeholder(dtypes.float32, shape=[]) - enter_1 = control_flow_ops.enter(data, 'foo_1', False) - enter_2 = control_flow_ops.enter(data, 'foo_2', False) + # pylint: disable=protected-access + enter_1 = gen_control_flow_ops._enter(data, 'foo_1', False) + enter_2 = gen_control_flow_ops._enter(data, 'foo_2', False) + # pylint: enable=protected-access res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError('has inputs from different frames'): sess.run(res, feed_dict={data: 1.0}) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 43cbde69d9db20d85c55e071d8393074a78a4a1b..8b8adefa65a5c54d40bc28d8f50953513cfd3605 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -357,6 +357,9 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:string_ops", + "//tensorflow/python:lookup_ops", ], grpc_enabled = True, tags = [ diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py index 45dfa13720b09c7bba979b72a339c13dcd2d827b..25c91b42dc65f849a680e65fc7fc2548c1cea8ea 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py @@ -17,10 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -28,6 +31,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test @@ -103,6 +109,67 @@ class IteratorClusterTest(test.TestCase): "/job:worker/replica:0/task:1/cpu:0", workers[0].target) + def testCaptureHashTableInSharedIterator(self): + worker, _ = test_util.create_local_cluster(1, 1) + + # NOTE(mrry): We must use the V2 variants of `HashTable` + # etc. because these produce a `tf.resource`-typed output that is + # compatible with the in-graph function implementation. + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer(keys, values), + default_val, + shared_name="shared_table") + + input_sentences = dataset_ops.Dataset.from_tensor_slices( + ["brain brain tank salad surgery", "surgery brain"]) + + iterator = ( + input_sentences.map(lambda x: string_ops.string_split([x]).values).map( + table.lookup) + .make_initializable_iterator(shared_name="shared_iterator")) + init_op = iterator.initializer + get_next = iterator.get_next() + + with session.Session(worker[0].target) as sess: + sess.run(table.init) + sess.run(init_op) + self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next)) + + with session.Session(worker[0].target) as sess: + self.assertAllEqual([2, 0], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testImplicitDisposeParallelMapDataset(self): + # Tests whether a parallel map dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. + # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> + # RepeatDataset(None) -> PrefetchDataset(100). + worker, _ = test_util.create_local_cluster(1, 1) + + components = (np.arange(1000), + np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], + np.array(37.0) * np.arange(1000)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(None).prefetch(10000)) + + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with session.Session(worker[0].target) as sess: + sess.run(init_op) + for _ in range(3): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index c1ba67e4744c6282f0fd3d9a388aabc1ed51267b..c4b7e4919bbbdb4c2096f124b54c264fa62e3fab 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -41,8 +41,10 @@ from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export +@tf_export("data.Dataset") class Dataset(object): """Represents a potentially large set of elements. @@ -556,6 +558,8 @@ class Dataset(object): - /path/to/dir/b.py - /path/to/dir/c.py + NOTE: The order of the file names returned can be non-deterministic. + Args: file_pattern: A string or scalar string `tf.Tensor`, representing the filename pattern that will be matched. @@ -899,10 +903,11 @@ class Dataset(object): Args: transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. + returns a `Dataset`. Returns: - Dataset: The `Dataset` returned by applying `transformation_func` to this dataset. + Dataset: The `Dataset` returned by applying `transformation_func` to this + dataset. """ dataset = transformation_func(self) if not isinstance(dataset, Dataset): @@ -1454,6 +1459,19 @@ def _padding_value_to_tensor(value, output_type): return value +def _default_padding(input_dataset): + + def make_zero(t): + if t.base_dtype == dtypes.string: + return "" + elif t.base_dtype == dtypes.variant: + raise TypeError("Unable to create padding for field of type 'variant'") + else: + return np.zeros_like(t.as_numpy_dtype()) + + return nest.map_structure(make_zero, input_dataset.output_types) + + class PaddedBatchDataset(Dataset): """A `Dataset` that batches and pads contiguous elements from its input.""" @@ -1469,23 +1487,13 @@ class PaddedBatchDataset(Dataset): batch_size, dtype=dtypes.int64, name="batch_size") padding_values = ( padding_values - if padding_values is not None else self._default_padding(input_dataset)) + if padding_values is not None else _default_padding(input_dataset)) self._padded_shapes = nest.map_structure_up_to( input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes) self._padding_values = nest.map_structure_up_to( input_dataset.output_shapes, _padding_value_to_tensor, padding_values, input_dataset.output_types) - def _default_padding(self, input_dataset): - - def make_zero(t): - if t.base_dtype == dtypes.string: - return "" - else: - return np.zeros_like(t.as_numpy_dtype()) - - return nest.map_structure(make_zero, input_dataset.output_types) - def _as_variant_tensor(self): return gen_dataset_ops.padded_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 53a3244ce1948803be5d8ee0f7db12fa40c8a32f..e573fe01928b77dea55a782e4e86a00873346f07 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export # NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple @@ -47,6 +48,7 @@ GET_NEXT_CALL_WARNING_MESSAGE = ( "`next_element` inside the loop.") +@tf_export("data.Iterator") class Iterator(object): """Represents the state of iterating through a `Dataset`.""" diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index 830dc5cec4a54469d001f0ba57d1adc7bc5efd11..fa7601741b11f018e9b53ed3b77a7561be50d3f4 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -23,12 +23,14 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export # TODO(b/64974358): Increase default buffer size to 256 MB. _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB +@tf_export("data.TextLineDataset") class TextLineDataset(Dataset): """A `Dataset` comprising lines from one or more text files.""" @@ -71,6 +73,7 @@ class TextLineDataset(Dataset): return dtypes.string +@tf_export("data.TFRecordDataset") class TFRecordDataset(Dataset): """A `Dataset` comprising records from one or more TFRecord files.""" @@ -115,6 +118,7 @@ class TFRecordDataset(Dataset): return dtypes.string +@tf_export("data.FixedLengthRecordDataset") class FixedLengthRecordDataset(Dataset): """A `Dataset` of fixed-length records from one or more binary files.""" diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index df5498be5f0643dd533cb522423e5728d389d7fb..e90ce3fb40af68fb68d6ee8bac6892848d8c5a79 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -383,8 +383,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): "structure has keys %s, while shallow structure has keys %s." % (list(_six.iterkeys(input_tree)), list(_six.iterkeys(shallow_tree)))) - input_tree = list(_six.iteritems(input_tree)) - shallow_tree = list(_six.iteritems(shallow_tree)) + input_tree = list(sorted(_six.iteritems(input_tree))) + shallow_tree = list(sorted(_six.iteritems(shallow_tree))) for shallow_branch, input_branch in zip(shallow_tree, input_tree): assert_shallow_structure(shallow_branch, input_branch, @@ -479,8 +479,8 @@ def map_structure_up_to(shallow_tree, func, *inputs): The `inputs`, can be thought of as having the same structure as `shallow_tree`, but with leaf nodes that are themselves tree structures. - This function, therefore, will return something with the same base structure as - `shallow_tree`. + This function, therefore, will return something with the same base structure + as `shallow_tree`. Examples: diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py index 90dd7dfe7775b2f10611e5579784fbda63fc9669..ff380815a4a32192de621888199e66355f9b4635 100644 --- a/tensorflow/python/data/util/nest_test.py +++ b/tensorflow/python/data/util/nest_test.py @@ -277,6 +277,10 @@ class NestTest(test.TestCase): with self.assertRaisesRegexp(ValueError, expected_message): nest.assert_shallow_structure(inp_ab2, inp_ab1) + inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) + inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) + nest.assert_shallow_structure(inp_ab, inp_ba) + def testFlattenUpTo(self): input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5))) shallow_tree = ((True, True), (False, True)) diff --git a/tensorflow/python/debug/cli/tensor_format.py b/tensorflow/python/debug/cli/tensor_format.py index d4aea76d652e7606939f3d8a89ff0378da0774d2..e0759a8bc1ab271906fc4ec75b55529f8a0d2b74 100644 --- a/tensorflow/python/debug/cli/tensor_format.py +++ b/tensorflow/python/debug/cli/tensor_format.py @@ -535,7 +535,7 @@ def numeric_summary(tensor): if not isinstance(tensor, np.ndarray) or not np.size(tensor): return debugger_cli_common.RichTextLines([ "No numeric summary available due to empty tensor."]) - elif (np.issubdtype(tensor.dtype, np.float) or + elif (np.issubdtype(tensor.dtype, np.floating) or np.issubdtype(tensor.dtype, np.complex) or np.issubdtype(tensor.dtype, np.integer)): counts = [ diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index c4b13a1045dac4966b0e841155a2932216881d34..8d355aa27f6fa10a1889420a9087800be12a81ce 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -222,7 +222,7 @@ def has_inf_or_nan(datum, tensor): # Also return False for data types that cannot be represented as numpy # arrays. return False - elif (np.issubdtype(tensor.dtype, np.float) or + elif (np.issubdtype(tensor.dtype, np.floating) or np.issubdtype(tensor.dtype, np.complex) or np.issubdtype(tensor.dtype, np.integer)): return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor)) diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index b6c7280a415b367751c4900a302e5af61f260cb0..c1e9869d978e4f5ddfd3cd5f1abd7f5c97b7ca88 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -22,6 +22,7 @@ import shutil import tempfile from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_gradients @@ -38,7 +39,11 @@ from tensorflow.python.training import gradient_descent class IdentifyGradientTest(test_util.TensorFlowTestCase): def setUp(self): - self.sess = session.Session() + rewriter_config = rewriter_config_pb2.RewriterConfig( + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + config = config_pb2.ConfigProto(graph_options=graph_options) + self.sess = session.Session(config=config) with self.sess.as_default(): self.u = variables.Variable(2.0, name="u") self.v = variables.Variable(3.0, name="v") diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index 367b3535450ac4bd17d4c5dba0eaf149aa4b68b3..b623ee31c5dc59894373ec7952e53acd0f6e1126 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -54,7 +54,8 @@ from tensorflow.python.training import monitored_session def no_rewrite_session_config(): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py index acea9433e22203d56f4ceb6cd92b681e35876a09..254201c39371e2034b08fad927e98418c8086ea5 100644 --- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py @@ -389,6 +389,11 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): r"mode\."): sess.invoke_node_stepper(node_stepper) + def testDumpingWrapperWithEmptyFetchWorks(self): + sess = dumping_wrapper.DumpingDebugWrapperSession( + self.sess, session_root=self.session_root, log_usage=False) + sess.run([]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index 909150eb6aa21b45af39f7cbfd6248c701ae1fb5..c530204bbf6959f56a72c6e67add91f1e575f067 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -121,7 +121,9 @@ from tensorflow.python.debug.lib import debug_utils from tensorflow.python.debug.lib import stepper from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.platform import tf_logging from tensorflow.python.training import monitored_session +from tensorflow.python.util import nest # Helper function. @@ -439,7 +441,12 @@ class BaseDebugWrapperSession(session.SessionInterface): "callable_runner and fetches/feed_dict are mutually exclusive, but " "are used simultaneously.") - if self._is_disabled_thread(): + empty_fetches = not nest.flatten(fetches) + if empty_fetches: + tf_logging.info( + "Due to empty fetches, tfdbg Session wrapper is letting a " + "Session.run pass through without any debugging actions.") + if self._is_disabled_thread() or empty_fetches: if callable_runner: return callable_runner(*callable_runner_args) else: diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index 770a496aa9d2f4bb8bee0f51526ba8c3d4278b81..490812c96d83791cdc20c56f16c968f1a1851af8 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -664,6 +664,20 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): [["run"], ["run"]], monitored_sess) self.assertFalse(wrapped_monitored_sess.should_stop()) + def testRunsWithEmptyFetchWorks(self): + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]], self.sess, dump_root="") + + run_output = wrapped_sess.run([]) + self.assertEqual([], run_output) + + def testRunsWithEmptyNestedFetchWorks(self): + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]], self.sess, dump_root="") + + run_output = wrapped_sess.run({"foo": {"baz": []}, "bar": ()}) + self.assertEqual({"foo": {"baz": []}, "bar": ()}, run_output) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9e3382d4f301529cd2b476bc76efe7dfd2be9298..ab81d40148476735492890f608315b19eaa0a33f 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -206,29 +206,6 @@ cc_library( ], ) -cc_library( - name = "python_eager_op_gen_main", - srcs = [ - "python_eager_op_gen_main.cc", - ], - visibility = ["//visibility:public"], - deps = [ - ":python_eager_op_gen", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:op_gen_lib", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_cc_binary( - name = "python_eager_op_gen_demo", - deps = [ - ":python_eager_op_gen_main", - "//tensorflow/core:ops", - ], -) - py_library( name = "custom_gradient", srcs = ["custom_gradient.py"], diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index d79d1fc0a6400a894293f3254d5cac5a10661e13..898f5e90f3718fb316f0ba20a14977ef352aee24 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -157,6 +157,8 @@ _ops_which_dont_need_outputs = set([ "SegmentMax", "UnsortedSegmentSum", "UnsortedSegmentMax", + "UnsortedSegmentMin", + "UnsortedSegmentProd", "Abs", "Neg", "ReciprocalGrad", diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py index 2f1654dda499583fe4766cbe2e330399defc96fd..988442c971f7bf978f1848278fd4955d79428fc5 100644 --- a/tensorflow/python/eager/execution_callbacks.py +++ b/tensorflow/python/eager/execution_callbacks.py @@ -153,7 +153,7 @@ def inf_nan_callback(op_type, continue numpy_dtype = output.dtype.as_numpy_dtype - if (np.issubdtype(numpy_dtype, np.float) or + if (np.issubdtype(numpy_dtype, np.floating) or np.issubdtype(numpy_dtype, np.complex) or np.issubdtype(numpy_dtype, np.integer)): try: diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 81b1f6f12a1899ddccb711a81122905bfd363748..246df9afefbf7665fa100bb601efcf3df717458d 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -292,6 +292,22 @@ def _map_sequence_obj_to_idx(sequence): return {id(x): i for i, x in enumerate(sequence)} +def _flatten(sequence): + """A wrapper around `nest.flatten` that also unpacks `IndexedSlices`.""" + # TODO(akshayka): Support `SparseTensor` in a similar fashion. + flat_sequence = nest.flatten(sequence) + outputs = [] + for item in flat_sequence: + if isinstance(item, ops.IndexedSlices): + if item.dense_shape is not None: + outputs.extend([item.values, item.indices, item.dense_shape]) + else: + outputs.extend([item.values, item.indices]) + else: + outputs.append(item) + return outputs + + class GraphModeFunction(object): """Callable object representing a graph-mode function. @@ -333,14 +349,14 @@ class GraphModeFunction(object): self._input_placeholders = input_placeholders self._extra_inputs = list(extra_inputs) self._graph = graph - self._has_backprop = False + self._backward_function = None self._func_name = name self._function_def = defined_function self._num_outputs = len(defined_function.signature.output_arg) self._ops = operations self._func_outputs = func_outputs self._returns = [func_outputs] if isinstance( - func_outputs, (ops.Tensor, type(None))) else list(func_outputs) + func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs) self._output_shapes = output_shapes self._variables = variables if variables is not None else [] @@ -348,9 +364,8 @@ class GraphModeFunction(object): def variables(self): return self._variables - def _compute_backprop(self): - """Computes the backprop function object for this function.""" - self._has_backprop = True + def _construct_backprop_function(self): + """Constructs the backprop function object for this function.""" with self._graph.as_default(), context.graph_mode(): c = _CapturingContext() with c: @@ -361,13 +376,16 @@ class GraphModeFunction(object): filtered_outputs, self._input_placeholders, grad_ys=self._out_grad_placeholders) - shapes = tuple(x.shape for x in in_gradients if x is not None) + + backward_outputs = tuple( + grad for grad in _flatten(in_gradients) if grad is not None) + output_shapes = tuple(grad.shape for grad in backward_outputs) + captures = list(sorted(c.captured_tensors, key=lambda x: x.name)) forward_name = _forward_name(self._func_name) self._forward_fdef = _EagerDefinedFunction( forward_name, self._graph, self._ops, self._input_placeholders, filtered_outputs + captures) - backward_outputs = tuple(x for x in in_gradients if x is not None) all_inputs = self._out_grad_placeholders + captures # Excluding input ops from the body as we do not intend to execute these # operations when the function is executed. @@ -381,7 +399,7 @@ class GraphModeFunction(object): bname = _backward_name(self._func_name) self._backward_function = GraphModeFunction( bname, all_inputs, [], self._graph, function_def_ops, - backward_outputs, in_gradients, shapes) + backward_outputs, in_gradients, output_shapes) def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" @@ -426,9 +444,24 @@ class GraphModeFunction(object): @property def output_shapes(self): + """The function's output shapes.""" # TODO(ebrevdo): Should we only keep the output shapes associated # with len(self._returns) outputs? - return nest.pack_sequence_as(self._func_outputs, self._output_shapes) + outputs_list = nest.flatten(self._func_outputs) + j = 0 + for i, o in enumerate(outputs_list): + if o is not None: + if isinstance(o, ops.IndexedSlices): + # Extract the shape of the `IndexedSlices` object's `values` field. + outputs_list[i] = self._output_shapes[j] # the `values` shape + if o.dense_shape is not None: + j += 3 # skip over shapes for `values`, `indices`, `dense_shape` + else: + j += 2 # skip over shapes for `values`, `indices` + else: + outputs_list[i] = self._output_shapes[j] + j += 1 + return nest.pack_sequence_as(self._func_outputs, outputs_list) @property def output_dtypes(self): @@ -457,12 +490,11 @@ class GraphModeFunction(object): if v._trainable: # pylint: disable=protected-access tape.watch_variable(v) - tensor_inputs = [x for x in nest.flatten(args) - if isinstance(x, ops.Tensor)] + tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] if tape.should_record(tensor_inputs) or tape.should_record( self._extra_inputs): - if not self._has_backprop: - self._compute_backprop() + if self._backward_function is None: + self._construct_backprop_function() return self._backprop_call(tensor_inputs) ctx = context.context() @@ -503,13 +535,30 @@ class GraphModeFunction(object): """ if self._func_outputs is None: return None + # Use `nest.flatten` instead of `_flatten` in order to preserve any + # IndexedSlices in `self._func_outputs`. outputs_list = nest.flatten(self._func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: - outputs_list[i] = result[j] - j += 1 - return nest.pack_sequence_as(self._func_outputs, outputs_list) + if isinstance(o, ops.IndexedSlices): + # Repack Tensors for IndexedSlices. + if o.dense_shape is not None: + outputs_list[i] = ops.IndexedSlices( + values=result[j], + indices=result[j + 1], + dense_shape=result[j + 2]) + j += 3 + else: + outputs_list[i] = ops.IndexedSlices( + values=result[j], + indices=result[j + 1]) + j += 2 + else: + outputs_list[i] = result[j] + j += 1 + ret = nest.pack_sequence_as(self._func_outputs, outputs_list) + return ret def _get_defun_inputs(args): @@ -526,15 +575,13 @@ def _get_defun_inputs(args): def _defun_internal(name, func, args, kwds): """Defines and returns graph-mode version of func.""" - container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access + graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): captures = {} tmp_graph = CapturingGraph(captures) - # Inherit the container prefix, since this is used for error checking when - # isolating eager execution (the container prefix at creation must match the - # container prefix when used, and variables accessed in the defun will be - # used in the outside context). - tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access + # Inherit the graph key, since this is used for matching variables in + # optimizers. + tmp_graph._graph_key = graph_key # pylint: disable=protected-access # Copy the graph collections to ensure summaries and other things work. This # lets the function access (but not mutate) collections of the containing # graph, such as the global step and the summary writer collections. @@ -555,7 +602,7 @@ def _defun_internal(name, func, args, kwds): # Returning a closed-over tensor as an output does not trigger a # call to convert_to_tensor, so we manually capture all such tensors. - outputs_list = nest.flatten(func_outputs) + outputs_list = _flatten(func_outputs) func_def_outputs = [ _convert_to_graph_tensor(x) for x in outputs_list if x is not None ] @@ -600,6 +647,18 @@ def _cache_key(x): """Cache key for tfe functions.""" if isinstance(x, ops.Tensor): return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access + if isinstance(x, ops.IndexedSlices): + if x.dense_shape is not None: + return tuple([ + _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access + _TensorDtype(x.indices.dtype, x.indices._shape_tuple()), # pylint: disable=protected-access + _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple()) # pylint: disable=protected-access + ]) + else: + return tuple([ + _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access + _TensorDtype(x.indices.dtype, x.indices._shape_tuple()) # pylint: disable=protected-access + ]) if isinstance(x, np.ndarray): return ("array", x.shape, tuple(x.reshape(-1))) if isinstance(x, (list, tuple)): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 2cb2cfb76c7c1713c66a060fc227331bb1acaf71..3e8e67ac7e242887e1c4f7d89a2e2bc395db22fe 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -374,6 +374,78 @@ class FunctionTest(test.TestCase): self.assertAllEqual(f(constant_op.constant(1.0)), 2.0) + def testGradientOfGatherWithDefun(self): + + v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + def sum_gather(): + return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) + + grad_fn = backprop.implicit_grad(sum_gather) + gradient = grad_fn() + defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather)) + defun_gradient = defun_grad_fn() + self.assertEqual(len(gradient), len(defun_gradient)) + + gradient = gradient[0][0] + defun_gradient = defun_gradient[0][0] + self.assertAllEqual(gradient.values, defun_gradient.values) + self.assertAllEqual(gradient.indices, defun_gradient.indices) + self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape) + + def testReturningIndexedSlicesWithDefun(self): + + def validate(indexed_slice): + def f(): + return indexed_slice + + output = function.defun(f)() + self.assertTrue(isinstance(output, ops.IndexedSlices)) + self.assertAllEqual(indexed_slice.values, output.values) + self.assertAllEqual(indexed_slice.indices, output.indices) + self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape) + + self.assertEqual( + function.make_defun_op(f).output_shapes, indexed_slice.values.shape) + + arg = ops.IndexedSlices( + values=constant_op.constant([1, 2]), + indices=constant_op.constant([0, 1]), + dense_shape=constant_op.constant([2])) + validate(arg) + + arg = ops.IndexedSlices( + values=constant_op.constant([1, 2]), + indices=constant_op.constant([0, 1]), + dense_shape=None) + validate(arg) + + def testIndexedSliceAsArgumentWithDefun(self): + + @function.defun + def f(indexed_slice): + return indexed_slice + + def validate(arg): + output = f(arg) + self.assertTrue(isinstance(output, ops.IndexedSlices)) + self.assertAllEqual(arg.values, output.values) + self.assertAllEqual(arg.indices, output.indices) + self.assertAllEqual(arg.dense_shape, output.dense_shape) + + indexed_slice = ops.IndexedSlices( + values=constant_op.constant([1]), + indices=constant_op.constant([0]), + dense_shape=constant_op.constant([1])) + validate(indexed_slice) + + # Test that `f` works even when `dense_shape` is None. + indexed_slice = ops.IndexedSlices( + values=constant_op.constant([1]), + indices=constant_op.constant([0]), + dense_shape=None) + validate(indexed_slice) + def testFunctionOnDevice(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') diff --git a/tensorflow/python/eager/gen_op.bzl b/tensorflow/python/eager/gen_op.bzl deleted file mode 100644 index 8bc1d6c10a60b89a026cb34dbf6fd98d29e909c2..0000000000000000000000000000000000000000 --- a/tensorflow/python/eager/gen_op.bzl +++ /dev/null @@ -1,65 +0,0 @@ -"""For eager-mode Python.""" - -load("//tensorflow:tensorflow.bzl", - "clean_dep", - "tf_binary_additional_srcs", - "tf_copts", - "tf_cc_binary") - -def tfe_gen_op_wrapper_py(name, - out=None, - visibility=None, - deps=[], - generated_target_name=None, - # ApiDefs will be loaded in the order specified in this list. - api_def_srcs=[]): - """Generate an eager-mode Python op wrapper for an op library.""" - # Construct a cc_binary containing the specified ops. - tool_name = "gen_" + name + "_py_wrappers_cc" - if not deps: - deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))] - tf_cc_binary( - name=tool_name, - linkopts=["-lm"], - copts=tf_copts(), - linkstatic=1, - deps=([ - clean_dep("//tensorflow/python/eager:python_eager_op_gen_main") - ] + deps), - visibility=[clean_dep("//visibility:public")],) - - # Invoke the previous cc_binary to generate a python file. - if not out: - out = "gen_" + name + ".py" - - if not api_def_srcs: - api_def_args_str = "," - else: - api_def_args = [] - for api_def_src in api_def_srcs: - # Add directory of the first ApiDef source to args. - # We are assuming all ApiDefs in a single api_def_src are in the - # same directory. - api_def_args.append( - "$$(dirname $$(echo $(locations " + api_def_src + - ") | cut -d\" \" -f1))") - api_def_args_str = ",".join(api_def_args) - - native.genrule( - name=name + "_pygenrule", - outs=[out], - srcs=api_def_srcs, - tools=[tool_name] + tf_binary_additional_srcs(), - cmd=("$(location " + tool_name + ") " + api_def_args_str + " > $@")) - - # Make a py_library out of the generated python file. - if not generated_target_name: - generated_target_name = name - native.py_library( - name=generated_target_name, - srcs=[out], - srcs_version="PY2AND3", - visibility=visibility, - deps=[ - clean_dep("//tensorflow/python/eager:framework_for_generated_wrappers"), - ],) diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 5c13ea89081a7d060c0ed1201f0169b739a204c2..62106bf0e2809e3c056e4a357f3d05251b7dca68 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -252,21 +252,17 @@ def _graph_callable_internal(func, shape_and_dtypes): Callable graph object. """ container = tf_ops.get_default_graph()._container # pylint: disable=protected-access - container_prefix = tf_ops.get_default_graph()._container_prefix # pylint: disable=protected-access + graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): # This graph will store both the initialization and the call version of the # wrapped function. It will later be used by the backprop code to build the # backprop graph, if necessary. captures = {} tmp_graph = function.CapturingGraph(captures) - # Inherit the container from the original graph to create resources at user - # expected containers. Also inherits the container prefix, since this is - # used for error checking when isolating Eager execution (the container - # prefix at creation must match the container prefix when used, and - # variables returned from the graph callable will be used in the outside - # context). + # Inherit the graph key from the original graph to ensure optimizers don't + # misbehave. tmp_graph._container = container # pylint: disable=protected-access - tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access + tmp_graph._graph_key = graph_key # pylint: disable=protected-access with tmp_graph.as_default(): # Placeholders for the non-variable inputs. func_inputs = _get_graph_callable_inputs(shape_and_dtypes) diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index 90a8779ff845b2fd63d1ba1019e8601fef257e42..0f18f28c9556c52c7f0d571dbdfa932dcb7e2721 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -756,11 +756,21 @@ from tensorflow.python.util.tf_export import tf_export auto out = cleaned_ops.mutable_op(); out->Reserve(ops.op_size()); for (const auto& op_def : ops.op()) { - bool is_hidden = false; - for (const string& hidden : hidden_ops) { - if (op_def.name() == hidden) { - is_hidden = true; - break; + const auto* api_def = api_defs.GetApiDef(op_def.name()); + + if (api_def->visibility() == ApiDef::SKIP) { + continue; + } + + // An op is hidden if either its ApiDef visibility is HIDDEN + // or it is in the hidden_ops list. + bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + if (!is_hidden) { + for (const string& hidden : hidden_ops) { + if (op_def.name() == hidden) { + is_hidden = true; + break; + } } } @@ -777,7 +787,6 @@ from tensorflow.python.util.tf_export import tf_export continue; } - const auto* api_def = api_defs.GetApiDef(op_def.name()); strings::StrAppend(&result, GetEagerPythonOp(op_def, *api_def, function_name)); diff --git a/tensorflow/python/eager/python_eager_op_gen_main.cc b/tensorflow/python/eager/python_eager_op_gen_main.cc deleted file mode 100644 index 05351bd8b115ae07482b82166974e86758bc7712..0000000000000000000000000000000000000000 --- a/tensorflow/python/eager/python_eager_op_gen_main.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/python/eager/python_eager_op_gen.h" - -#include -#include -#include - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" - -namespace tensorflow { -namespace { - -void PrintAllPythonOps(const std::vector& hidden_ops, - const std::vector& api_def_dirs) { - OpList ops; - OpRegistry::Global()->Export(false, &ops); - - ApiDefMap api_def_map(ops); - if (!api_def_dirs.empty()) { - Env* env = Env::Default(); - - for (const auto& api_def_dir : api_def_dirs) { - std::vector api_files; - TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"), - &api_files)); - TF_CHECK_OK(api_def_map.LoadFileList(env, api_files)); - } - api_def_map.UpdateDocs(); - } - - PrintEagerPythonOps(ops, api_def_map, hidden_ops, true /* require_shapes */); -} - -} // namespace -} // namespace tensorflow - -int main(int argc, char* argv[]) { - tensorflow::port::InitMain(argv[0], &argc, &argv); - - // Usage: - // python_eager_op_gen_main api_def_dir1,api_def_dir2,... - if (argc == 1) { - tensorflow::PrintAllPythonOps({}, {}); - } else if (argc == 2) { - const std::vector api_def_dirs = - tensorflow::str_util::Split(argv[1], ",", - tensorflow::str_util::SkipEmpty()); - tensorflow::PrintAllPythonOps({}, api_def_dirs); - } else { - return -1; - } - return 0; -} diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 836998cfdc39ced78732a29b6a11329077239ca5..d927f3abedb88deddabd4c4d931d12053005a3ff 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -528,6 +528,34 @@ tensorflow::gtl::CompactPointerSet* GetTapeSet() { return tape_set; } +// A safe copy of the current tapeset. Does not get affected by other python +// threads changing the set of active tapes. +class SafeTapeSet { + public: + SafeTapeSet() : tape_set_(*GetTapeSet()) { + for (auto* tape : tape_set_) { + Py_INCREF(tape); + } + } + + ~SafeTapeSet() { + for (auto* tape : tape_set_) { + Py_DECREF(tape); + } + } + + tensorflow::gtl::CompactPointerSet::const_iterator begin() { + return tape_set_.begin(); + } + + tensorflow::gtl::CompactPointerSet::const_iterator end() { + return tape_set_.end(); + } + + private: + tensorflow::gtl::CompactPointerSet tape_set_; +}; + // xcode 7 doesn't define thread_local, so for compatibility we implement our // own. TODO(apassos) remove once we can deprecate xcode 7. #ifndef __APPLE__ @@ -718,10 +746,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) { if (*ThreadTapeIsStopped()) { return; } - // Note: making a copy because watching a variable can trigger a change to the - // set of tapes by allowing python's garbage collector to run. - auto tape_set = *GetTapeSet(); - for (TFE_Py_Tape* tape : tape_set) { + for (TFE_Py_Tape* tape : SafeTapeSet()) { tape->tape->WatchVariable(variable); } } @@ -777,8 +802,7 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, return; } - auto set = *GetTapeSet(); - for (TFE_Py_Tape* tape : set) { + for (TFE_Py_Tape* tape : SafeTapeSet()) { Py_INCREF(backward_function); tape->tape->RecordOperation( op_type_str, output_info, input_ids, backward_function, @@ -787,10 +811,7 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, } void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { - // Note: making a copy because deleting the trace can trigger a change to the - // set of tapes by allowing python's garbage collector to run. - auto tape_set = *GetTapeSet(); - for (TFE_Py_Tape* tape : tape_set) { + for (TFE_Py_Tape* tape : SafeTapeSet()) { tape->tape->DeleteTrace(tensor_id); } } diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py index 96e4ecd29fbcd4f4335077e9f81c5704ae2b9bec..138152ac1c6b2d7e399218208dd7bdf2d8136f5e 100644 --- a/tensorflow/python/estimator/canned/baseline.py +++ b/tensorflow/python/estimator/canned/baseline.py @@ -57,6 +57,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.training import training_util # The default learning rate of 0.3 is a historical artifact of the initial @@ -220,7 +221,8 @@ class BaselineClassifier(estimator.Estimator): weight_column=None, label_vocabulary=None, optimizer='Ftrl', - config=None): + config=None, + loss_reduction=losses.Reduction.SUM): """Initializes a BaselineClassifier instance. Args: @@ -240,6 +242,8 @@ class BaselineClassifier(estimator.Estimator): optimizer to use for training. If not specified, will use `FtrlOptimizer` with a default learning rate of 0.3. config: `RunConfig` object to configure the runtime settings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. Returns: A `BaselineClassifier` estimator. @@ -249,11 +253,13 @@ class BaselineClassifier(estimator.Estimator): if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): return _baseline_model_fn( features=features, @@ -311,7 +317,8 @@ class BaselineRegressor(estimator.Estimator): label_dimension=1, weight_column=None, optimizer='Ftrl', - config=None): + config=None, + loss_reduction=losses.Reduction.SUM): """Initializes a BaselineRegressor instance. Args: @@ -328,13 +335,16 @@ class BaselineRegressor(estimator.Estimator): optimizer to use for training. If not specified, will use `FtrlOptimizer` with a default learning rate of 0.3. config: `RunConfig` object to configure the runtime settings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. Returns: A `BaselineRegressor` estimator. """ head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access label_dimension=label_dimension, - weight_column=weight_column) + weight_column=weight_column, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): return _baseline_model_fn( features=features, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 96555b5e03c7a291480b3c30fe1f2c641c5c75e1..78d74b63d3ebea29e2dc3ab8f655efcc7ab8e130 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -478,13 +478,16 @@ class Estimator(object): estimator_spec = self._call_model_fn( features, None, model_fn_lib.ModeKeys.PREDICT, self.config) predictions = self._extract_keys(estimator_spec.predictions, predict_keys) + all_hooks = list(input_hooks) + all_hooks.extend(hooks) + all_hooks.extend(list(estimator_spec.prediction_hooks or [])) with training.MonitoredSession( session_creator=training.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, master=self._config.master, scaffold=estimator_spec.scaffold, config=self._session_config), - hooks=input_hooks + hooks) as mon_sess: + hooks=all_hooks) as mon_sess: while not mon_sess.should_stop(): preds_evaluated = mon_sess.run(predictions) if not isinstance(predictions, dict): diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 833f3dcac3b97962c967cba9ac7ab53a3b9c61f1..39a5b998ebdcccfbeddf0fc96dab44dc91a289fa 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1355,6 +1355,25 @@ class EstimatorPredictTest(test.TestCase): est.train(dummy_input_fn, steps=1) self.assertEqual(10., next(est.predict(dummy_input_fn))) + def test_predictionhooks_are_used(self): + hook = test.mock.MagicMock( + wraps=training.SessionRunHook(), spec=training.SessionRunHook) + + def _model_fn_hooks(features, labels, mode): + _, _ = features, labels + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions=constant_op.constant([[10.]]), + prediction_hooks=[hook]) + + est = estimator.Estimator(model_fn=_model_fn_hooks) + est.train(dummy_input_fn, steps=1) + self.assertFalse(hook.begin.called) + next(est.predict(dummy_input_fn)) + self.assertTrue(hook.begin.called) + def test_warn_if_no_queue_runner(self): def _model_fn(features, labels, mode): diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 51075731ddc52a55799958c3bfa6140f77404541..83251c79fc561e16ebddb638668b92b3c69b8af4 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -36,12 +36,14 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' +@tf_export('estimator.export.ServingInputReceiver') class ServingInputReceiver(collections.namedtuple( 'ServingInputReceiver', ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])): @@ -118,6 +120,7 @@ class ServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives=receiver_tensors_alternatives) +@tf_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): """Build a serving_input_receiver_fn expecting fed tf.Examples. @@ -146,6 +149,7 @@ def build_parsing_serving_input_receiver_fn(feature_spec, return serving_input_receiver_fn +@tf_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 863af6d41d985043542b03375372fe564c283b82..87b964be37197dac99b8ce4398cbdaf3b4989c7f 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -26,8 +26,10 @@ import six from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.util.tf_export import tf_export +@tf_export('estimator.export.ExportOutput') class ExportOutput(object): """Represents an output of a model that can be served. @@ -50,6 +52,7 @@ class ExportOutput(object): pass +@tf_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): """Represents the output of a classification head. @@ -118,6 +121,7 @@ class ClassificationOutput(ExportOutput): examples, self.classes, self.scores) +@tf_export('estimator.export.RegressionOutput') class RegressionOutput(ExportOutput): """Represents the output of a regression head.""" @@ -153,6 +157,7 @@ class RegressionOutput(ExportOutput): _SINGLE_OUTPUT_DEFAULT_NAME = 'output' +@tf_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index c4c2e30e8771c5cb1e492fed751c71583dcf477b..a6f471291008e3c27dea1aeea5865e334f76e5c8 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -24,6 +24,7 @@ import numpy as np from six import string_types from tensorflow.python.estimator.inputs.queues import feeding_functions +from tensorflow.python.util.tf_export import tf_export # Key name to pack the target into dict of `features`. See # `_get_unique_target_key` for details. @@ -86,6 +87,7 @@ def _validate_and_convert_features(x): return ordered_dict_data +@tf_export('estimator.inputs.numpy_input_fn') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index 90d6145377d8f931b94793f8a912f77f1620f16e..bd06843021f47f81fc0c22d0fcee43530dc10098 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.estimator.inputs.queues import feeding_functions +from tensorflow.python.util.tf_export import tf_export try: # pylint: disable=g-import-not-at-top @@ -34,6 +35,7 @@ except ImportError: HAS_PANDAS = False +@tf_export('estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index da202408c3680b397994620e221fa4937d7c65e4..b08f83fc569b1bb1ea6e5c93c57be7b5bb96f0a5 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -56,7 +56,7 @@ class EstimatorSpec( collections.namedtuple('EstimatorSpec', [ 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold', - 'evaluation_hooks' + 'evaluation_hooks', 'prediction_hooks' ])): """Ops and objects returned from a `model_fn` and passed to an `Estimator`. @@ -73,7 +73,8 @@ class EstimatorSpec( training_chief_hooks=None, training_hooks=None, scaffold=None, - evaluation_hooks=None): + evaluation_hooks=None, + prediction_hooks=None): """Creates a validated `EstimatorSpec` instance. Depending on the value of `mode`, different arguments are required. Namely @@ -154,6 +155,8 @@ class EstimatorSpec( initialization, saver, and more to be used in training. evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run during evaluation. + prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to + run during predictions. Returns: A validated `EstimatorSpec` object. @@ -282,7 +285,10 @@ class EstimatorSpec( training_chief_hooks = tuple(training_chief_hooks or []) training_hooks = tuple(training_hooks or []) evaluation_hooks = tuple(evaluation_hooks or []) - for hook in training_hooks + training_chief_hooks + evaluation_hooks: + prediction_hooks = tuple(prediction_hooks or []) + + for hook in (training_hooks + training_chief_hooks + evaluation_hooks + + prediction_hooks): if not isinstance(hook, session_run_hook.SessionRunHook): raise TypeError( 'All hooks must be SessionRunHook instances, given: {}'.format( @@ -305,7 +311,8 @@ class EstimatorSpec( training_chief_hooks=training_chief_hooks, training_hooks=training_hooks, scaffold=scaffold, - evaluation_hooks=evaluation_hooks) + evaluation_hooks=evaluation_hooks, + prediction_hooks=prediction_hooks) def _replace(self, **kwds): """Return a new EstimatorSpec replacing specified fields with new values.""" diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index d67c4b716161816d941eef94a4b9aeb0643de55e..b7eeeb437cb4a624cdee552be3032364b18a8290 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -72,7 +72,8 @@ class EstimatorSpecTrainTest(test.TestCase): training_chief_hooks=[_FakeHook()], training_hooks=[_FakeHook()], scaffold=monitored_session.Scaffold(), - evaluation_hooks=[_FakeHook()]) + evaluation_hooks=[_FakeHook()], + prediction_hooks=[_FakeHook()]) def testLossNumber(self): """Tests that error is raised when loss is a number (not Tensor).""" @@ -465,7 +466,17 @@ class EstimatorSpecInferTest(test.TestCase): training_chief_hooks=[_FakeHook()], training_hooks=[_FakeHook()], scaffold=monitored_session.Scaffold(), - evaluation_hooks=[_FakeHook()]) + evaluation_hooks=[_FakeHook()], + prediction_hooks=[_FakeHook()]) + + def testPredictionHookInvalid(self): + with ops.Graph().as_default(), self.test_session(): + with self.assertRaisesRegexp( + TypeError, 'All hooks must be SessionRunHook instances'): + model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.PREDICT, + predictions=constant_op.constant(1.), + prediction_hooks=[_InvalidHook()]) def testPredictionsMissing(self): with ops.Graph().as_default(), self.test_session(): diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 52fb1d39ae2e9c84e4269785a72be4f9a495b73c..2e84c5014f6e17b34b38d2dfe5711b5b654553bb 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Classes and functions related to train_and_evaluate.""" from __future__ import absolute_import @@ -37,7 +36,6 @@ from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook from tensorflow.python.util import compat - _MAX_DELAY_SECS = 60 _DELAY_SECS_PER_WORKER = 5 _TF_CONFIG_ENV = 'TF_CONFIG' @@ -50,8 +48,7 @@ _TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER, def _validate_input_fn(input_fn): """Validates the `input_fn`.""" if not callable(input_fn): - raise TypeError( - '`input_fn` must be callable, given: {}'.format(input_fn)) + raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn)) def _validate_hooks(hooks): @@ -125,10 +122,7 @@ class TrainSpec( duration. Optional hooks run at various stages of training. """ - def __new__(cls, - input_fn, - max_steps=None, - hooks=None): + def __new__(cls, input_fn, max_steps=None, hooks=None): """Creates a validated `TrainSpec` instance. Args: @@ -161,16 +155,13 @@ class TrainSpec( hooks = _validate_hooks(hooks) return super(TrainSpec, cls).__new__( - cls, - input_fn=input_fn, - max_steps=max_steps, - hooks=hooks) + cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks) class EvalSpec( collections.namedtuple('EvalSpec', [ - 'input_fn', 'steps', 'name', 'hooks', 'exporters', - 'start_delay_secs', 'throttle_secs' + 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs', + 'throttle_secs' ])): """Configuration for the "eval" part for the `train_and_evaluate` call. @@ -417,8 +408,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Raises: ValueError: if environment variable `TF_CONFIG` is incorrectly set. """ - executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec, - eval_spec=eval_spec) + executor = _TrainingExecutor( + estimator=estimator, train_spec=train_spec, eval_spec=eval_spec) config = estimator.config if (config.task_type == run_config_lib.TaskType.EVALUATOR and @@ -561,9 +552,8 @@ class _TrainingExecutor(object): self._timer.update_last_triggered_step(global_step_value) self._evaluator.evaluate_and_export() else: - logging.info( - 'Skip the current checkpoint eval due to throttle secs ' - '({} secs).'.format(self._eval_throttle_secs)) + logging.info('Skip the current checkpoint eval due to throttle secs ' + '({} secs).'.format(self._eval_throttle_secs)) # Final export signal: For any eval result with global_step >= train # max_steps, the evaluator will send the final export signal. There is a @@ -576,8 +566,8 @@ class _TrainingExecutor(object): # # But here, throttle_secs will skip the next intermediate checkpoint and, # so, the double final export chance is very small. - evaluator = _TrainingExecutor._Evaluator( - self._estimator, self._eval_spec, self._train_spec.max_steps) + evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, + self._train_spec.max_steps) # When the underlying `Estimator` object saves a new checkpoint, we would # like this callback to be called so that evaluation and export can trigger. @@ -617,8 +607,7 @@ class _TrainingExecutor(object): raise ValueError('eval_spec.throttle_secs should be positive, given: {}.' 'It is used do determine how long each training ' 'iteration should go when train and evaluate ' - 'locally.'.format( - self._eval_spec.throttle_secs)) + 'locally.'.format(self._eval_spec.throttle_secs)) stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs) train_hooks = ( @@ -663,8 +652,9 @@ class _TrainingExecutor(object): if not config.master: jobs = config.cluster_spec.jobs - if (len(jobs) == 1 and len(config.cluster_spec.job_tasks(jobs[0])) == 1 - and config.task_type in _TRAINER_JOBS): + if (len(jobs) == 1 and + len(config.cluster_spec.job_tasks(jobs[0])) == 1 and + config.task_type in _TRAINER_JOBS): # For distributed training, config.master is empty if and only if it has # a single node in the cluster spec. In this case, we should not start # the server. @@ -679,9 +669,9 @@ class _TrainingExecutor(object): logging.info('Start Tensorflow server.') if config.session_config is None: - session_config=config_pb2.ConfigProto(log_device_placement=False) + session_config = config_pb2.ConfigProto(log_device_placement=False) else: - session_config=config_pb2.ConfigProto( + session_config = config_pb2.ConfigProto( log_device_placement=False, gpu_options=config.session_config.gpu_options) @@ -744,8 +734,7 @@ class _TrainingExecutor(object): global_step >= self._train_spec.max_steps): logging.info( 'Exiting evaluation, global_step=%s >= train max_steps=%s', - global_step, - self._train_spec.max_steps) + global_step, self._train_spec.max_steps) return latest_eval_result, should_early_stop = self._execute_evaluator_once( @@ -781,10 +770,9 @@ class _TrainingExecutor(object): # Throttle if necessary. elapsed_time = time.time() - start - difference = throttle_secs - elapsed_time + difference = throttle_secs - elapsed_time if difference > 0: - logging.info('Waiting %f secs before starting next eval run.', - difference) + logging.info('Waiting %f secs before starting next eval run.', difference) time.sleep(difference) return (eval_result, should_early_stop) @@ -929,8 +917,8 @@ class _EvalResult( if checkpoint_path: raise ValueError( 'checkpoint must be `None` if status is not {}; got status {}, ' - 'checkpoint_path {}'.format( - _EvalStatus.EVALUATED, status, checkpoint_path)) + 'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status, + checkpoint_path)) return super(_EvalResult, cls).__new__(cls, status, metrics, checkpoint_path) diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index ad95c71234f82457cb938ca55214b28086b033a2..48110ef57fcba897cc495323973b2f6761c3add4 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -415,8 +415,8 @@ def _warm_start(warm_start_settings): a stronger check for variable configuration than relying on users to examine the logs. """ - logging.info("Warm-starting from: ", - warm_start_settings.ckpt_to_initialize_from) + logging.info("Warm-starting from: %s", + (warm_start_settings.ckpt_to_initialize_from,)) # We have to deal with partitioned variables, since get_collection flattens # out the list. grouped_variables = {} diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 7feb209cc49c4be70387c44168dbdeea6d108d66..5947d8f6e2348b12dae8f8ee05c26ecd9e342fcd 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -157,6 +157,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export def _internal_input_layer(features, @@ -209,6 +210,7 @@ def _internal_input_layer(features, return array_ops.concat(output_tensors, 1) +@tf_export('feature_column.input_layer') def input_layer(features, feature_columns, weight_collections=None, @@ -329,6 +331,7 @@ class InputLayer(object): return self._input_layer_template.weights +@tf_export('feature_column.linear_model') def linear_model(features, feature_columns, units=1, @@ -498,6 +501,7 @@ def _transform_features(features, feature_columns): return outputs +@tf_export('feature_column.make_parse_example_spec') def make_parse_example_spec(feature_columns): """Creates parsing spec dictionary from input feature_columns. @@ -557,6 +561,7 @@ def make_parse_example_spec(feature_columns): return result +@tf_export('feature_column.embedding_column') def embedding_column( categorical_column, dimension, combiner='mean', initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, @@ -807,6 +812,7 @@ def shared_embedding_columns( return result +@tf_export('feature_column.numeric_column') def numeric_column(key, shape=(1,), default_value=None, @@ -881,6 +887,7 @@ def numeric_column(key, normalizer_fn=normalizer_fn) +@tf_export('feature_column.bucketized_column') def bucketized_column(source_column, boundaries): """Represents discretized dense input. @@ -970,6 +977,7 @@ def _assert_string_or_int(dtype, prefix): '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) +@tf_export('feature_column.categorical_column_with_hash_bucket') def categorical_column_with_hash_bucket(key, hash_bucket_size, dtype=dtypes.string): @@ -1026,6 +1034,7 @@ def categorical_column_with_hash_bucket(key, return _HashedCategoricalColumn(key, hash_bucket_size, dtype) +@tf_export('feature_column.categorical_column_with_vocabulary_file') def categorical_column_with_vocabulary_file(key, vocabulary_file, vocabulary_size=None, @@ -1145,6 +1154,7 @@ def categorical_column_with_vocabulary_file(key, dtype=dtype) +@tf_export('feature_column.categorical_column_with_vocabulary_list') def categorical_column_with_vocabulary_list( key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): """A `_CategoricalColumn` with in-memory vocabulary. @@ -1255,6 +1265,7 @@ def categorical_column_with_vocabulary_list( default_value=default_value, num_oov_buckets=num_oov_buckets) +@tf_export('feature_column.categorical_column_with_identity') def categorical_column_with_identity(key, num_buckets, default_value=None): """A `_CategoricalColumn` that returns identity values. @@ -1322,6 +1333,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None): key=key, num_buckets=num_buckets, default_value=default_value) +@tf_export('feature_column.indicator_column') def indicator_column(categorical_column): """Represents multi-hot representation of given categorical column. @@ -1350,6 +1362,7 @@ def indicator_column(categorical_column): return _IndicatorColumn(categorical_column) +@tf_export('feature_column.weighted_categorical_column') def weighted_categorical_column( categorical_column, weight_feature_key, dtype=dtypes.float32): """Applies weight values to a `_CategoricalColumn`. @@ -1424,6 +1437,7 @@ def weighted_categorical_column( dtype=dtype) +@tf_export('feature_column.crossed_column') def crossed_column(keys, hash_bucket_size, hash_key=None): """Returns a column for performing crosses of categorical features. diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index a4ca3f9a89bd4cce2240d90895c43dda1acb849b..b35cee0111266fd7b744161d8a7e75664cbda122 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function import re -import time import sys +import time import numpy as np @@ -86,6 +86,21 @@ class FunctionTest(test.TestCase): with session.Session() as sess: self.assertAllEqual([18.0], sess.run(call)) + def testIdentityImplicitDeref(self): + + @function.Defun(dtypes.float32, func_name="MyIdentity") + def MyIdentityFunc(a): + return a + + with ops.Graph().as_default(): + var = variables.Variable([18.0]) + call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access + self.assertEqual("MyIdentity", call.op.name) + for cfg in _OptimizerOptions(): + with session.Session(config=cfg) as sess: + sess.run(var.initializer) + self.assertAllEqual([18.0], sess.run(call)) + def testIdentityOutputName(self): @function.Defun( @@ -771,7 +786,7 @@ class FunctionTest(test.TestCase): # We added more randomness to function names in C API. # TODO(iga): Remove this if statement when we switch to C API. if ops._USE_C_API: # pylint: disable=protected-access - if sys.byteorder == 'big': + if sys.byteorder == "big": self.assertEqual("Foo_kEdkAG8SJvg", Foo.instantiate([dtypes.float32] * 3).name) else: diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index fc1a82361ba59cddc02a65a96da98283d871fd2c..8c03a5f19dee31a6609590e46d608af9a686c5fe 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -87,6 +87,10 @@ def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False): compat.as_str(s).split("@")[1].startswith(export_scope)] node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(s=new_s))) + elif node_def.op in ("Enter", "RefEnter") and k == "frame_name": + if not export_scope or compat.as_str(v.s).startswith(export_scope): + new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope)) + node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s)) else: node_def.attr[k].CopyFrom(v) @@ -959,5 +963,3 @@ def copy_scoped_meta_graph(from_scope, to_scope, graph=to_graph, import_scope=to_scope) return var_list - - diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index b5ed1352843eac31b3e34eb96385acd13a5bc7a9..f2f1e83da15eacdbb4f194967b51559d279ae1a4 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -25,6 +25,7 @@ import shutil from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -34,6 +35,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import nn_ops @@ -447,6 +449,56 @@ class ScopedMetaGraphTest(test.TestCase): del b.collection_def["unbound_inputs"] test_util.assert_meta_graph_protos_equal(self, a, b) + def testWhileLoopGradients(self): + # Create a simple while loop. + with ops.Graph().as_default(): + with ops.name_scope("export"): + var = variables.Variable(0) + var_name = var.name + _, output = control_flow_ops.while_loop(lambda i, x: i < 5, + lambda i, x: (i + 1, x + i), + [0, var]) + output_name = output.name + + # Generate a MetaGraphDef containing the while loop with an export scope. + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + export_scope="export") + + # Build and run the gradients of the while loop. We use this below to + # verify that the gradients are correct with the imported MetaGraphDef. + init_op = variables.global_variables_initializer() + grad = gradients_impl.gradients([output], [var]) + with session.Session() as sess: + sess.run(init_op) + expected_grad_value = sess.run(grad) + + # Restore the MetaGraphDef into a new Graph with an import scope. + with ops.Graph().as_default(): + meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import") + + # Re-export and make sure we get the same MetaGraphDef. + new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + export_scope="import") + test_util.assert_meta_graph_protos_equal( + self, meta_graph_def, new_meta_graph_def) + + # Make sure we can still build gradients and get the same result. + + def new_name(tensor_name): + base_tensor_name = tensor_name.replace("export/", "") + return "import/" + base_tensor_name + + var = ops.get_default_graph().get_tensor_by_name(new_name(var_name)) + output = ops.get_default_graph().get_tensor_by_name(new_name(output_name)) + grad = gradients_impl.gradients([output], [var]) + + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + actual_grad_value = sess.run(grad) + self.assertEqual(expected_grad_value, actual_grad_value) + def testScopedImportUnderNameScope(self): graph = ops.Graph() with graph.as_default(): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index d5786cac68dc31210f45f0af9ff6c347d93c026f..ea589cc4d401607621d916f775f30986bdffba8f 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2760,15 +2760,12 @@ class Graph(object): self._handle_movers = {} # A map from tensor handle to its delete op. self._handle_deleters = {} - # Resource container. - if context.in_graph_mode(): - self._container_prefix = "" - else: - # In Eager mode, isolate resources (particularly ResourceVariables) in - # Graphs by default. This prevents unintended variable sharing. Graph mode - # gets this kind of isolation from Sessions. - self._container_prefix = "eager-execution-%d/" % (uid(),) - self._container = self._container_prefix + # Allow optimizers and other objects to pseudo-uniquely key graphs (this key + # will be shared when defining function graphs, for example, so optimizers + # being called inside function definitions behave as if they were seeing the + # actual outside graph). + self._graph_key = "grap-key-%d/" % (uid(),) + self._container = "" self._registered_ops = op_def_registry.get_registered_ops() # TODO(skyewm): fold as much of the above as possible into the C @@ -4229,7 +4226,7 @@ class Graph(object): """ original_container = self._container try: - self._container = self._container_prefix + container_name + self._container = container_name yield self._container finally: self._container = original_container diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 65810fa7094409c7429dbaaa6c1e62efb263eafc..85cba59be4d8337322d6057d818341428649ecba 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -476,9 +476,6 @@ GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def, GenPythonOp::~GenPythonOp() {} string GenPythonOp::Code() { - if (api_def_.visibility() == ApiDef::SKIP) { - return ""; - } // This has all the input args followed by those attrs that don't have // defaults. std::vector params_no_default; @@ -805,11 +802,21 @@ from tensorflow.python.util.tf_export import tf_export auto out = cleaned_ops.mutable_op(); out->Reserve(ops.op_size()); for (const auto& op_def : ops.op()) { - bool is_hidden = false; - for (const string& hidden : hidden_ops) { - if (op_def.name() == hidden) { - is_hidden = true; - break; + const auto* api_def = api_defs.GetApiDef(op_def.name()); + + if (api_def->visibility() == ApiDef::SKIP) { + continue; + } + + // An op is hidden if either its ApiDef visibility is HIDDEN + // or it is in the hidden_ops list. + bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + if (!is_hidden) { + for (const string& hidden : hidden_ops) { + if (op_def.name() == hidden) { + is_hidden = true; + break; + } } } @@ -826,7 +833,6 @@ from tensorflow.python.util.tf_export import tf_export continue; } - const auto* api_def = api_defs.GetApiDef(op_def.name()); strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name)); if (!require_shapes) { diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..a0411bc3d9b4b2b87e5a31e9f201154f28ccf1cc --- /dev/null +++ b/tensorflow/python/framework/tensor_spec.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TensorSpec class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class TensorSpec(object): + """Describes a tf.Tensor. + + A TensorSpec allows an API to describe the Tensors that it accepts or + returns, before that Tensor exists. This allows dynamic and flexible graph + construction and configuration. + """ + + __slots__ = ["_shape", "_dtype", "_name"] + + def __init__(self, shape, dtype, name=None): + """Creates a TensorSpec. + + Args: + shape: Value convertible to `tf.TensorShape`. The shape of the tensor. + dtype: Value convertible to `tf.DType`. The type of the tensor values. + name: Optional name for the Tensor. + + Raises: + TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is + not convertible to a `tf.DType`. + """ + self._shape = tensor_shape.TensorShape(shape) + self._dtype = dtypes.as_dtype(dtype) + self._name = name + + @classmethod + def from_spec(cls, spec, name=None): + return cls(spec.shape, spec.dtype, name or spec.name) + + @classmethod + def from_tensor(cls, tensor, name=None): + if isinstance(tensor, ops.EagerTensor): + return TensorSpec(tensor.shape, tensor.dtype, name) + elif isinstance(tensor, ops.Tensor): + return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) + else: + raise ValueError("`tensor` should be a tf.Tensor") + + @property + def shape(self): + """Returns the `TensorShape` that represents the shape of the tensor.""" + return self._shape + + @property + def dtype(self): + """Returns the `dtype` of elements in the tensor.""" + return self._dtype + + @property + def name(self): + """Returns the name of the described tensor.""" + return self._name + + def is_compatible_with(self, spec_or_tensor): + """True if the shape and dtype of `spec_or_tensor` are compatible.""" + return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and + self._shape.is_compatible_with(spec_or_tensor.shape)) + + def __repr__(self): + return "TensorSpec(shape={}, dtype={}, name={})".format( + self.shape, repr(self.dtype), repr(self.name)) + + def __eq__(self, other): + return self.shape == other.shape and self.dtype == other.dtype + + def __ne__(self, other): + return not self == other + + +class BoundedTensorSpec(TensorSpec): + """A `TensorSpec` that specifies minimum and maximum values. + + Example usage: + ```python + spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5)) + tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype) + tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype) + ``` + + Bounds are meant to be inclusive. This is especially important for + integer types. The following spec will be satisfied by tensors + with values in the set {0, 1, 2}: + ```python + spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2) + ``` + """ + + __slots__ = ("_minimum", "_maximum") + + def __init__(self, shape, dtype, minimum, maximum, name=None): + """Initializes a new `BoundedTensorSpec`. + + Args: + shape: Value convertible to `tf.TensorShape`. The shape of the tensor. + dtype: Value convertible to `tf.DType`. The type of the tensor values. + minimum: Number or sequence specifying the minimum element bounds + (inclusive). Must be broadcastable to `shape`. + maximum: Number or sequence specifying the maximum element bounds + (inclusive). Must be broadcastable to `shape`. + name: Optional string containing a semantic name for the corresponding + array. Defaults to `None`. + + Raises: + ValueError: If `minimum` or `maximum` are not provided or not + broadcastable to `shape`. + TypeError: If the shape is not an iterable or if the `dtype` is an invalid + numpy dtype. + """ + super(BoundedTensorSpec, self).__init__(shape, dtype, name) + + if minimum is None or maximum is None: + raise ValueError("minimum and maximum must be provided; but saw " + "'%s' and '%s'" % (minimum, maximum)) + + try: + minimum_shape = np.shape(minimum) + common_shapes.broadcast_shape( + tensor_shape.TensorShape(minimum_shape), self.shape) + except ValueError as exception: + raise ValueError("minimum is not compatible with shape. " + "Message: {!r}.".format(exception)) + + try: + maximum_shape = np.shape(maximum) + common_shapes.broadcast_shape( + tensor_shape.TensorShape(maximum_shape), self.shape) + except ValueError as exception: + raise ValueError("maximum is not compatible with shape. " + "Message: {!r}.".format(exception)) + + self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype()) + self._minimum.setflags(write=False) + + self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype()) + self._maximum.setflags(write=False) + + @classmethod + def from_spec(cls, spec): + dtype = dtypes.as_dtype(spec.dtype) + if dtype in [dtypes.float64, dtypes.float32]: + # Avoid under/over-flow for `dtype.maximum - dtype.minimum`. + low = dtype.min / 2 + high = dtype.max / 2 + else: + low = dtype.min + high = dtype.max + + minimum = getattr(spec, "minimum", low) + maximum = getattr(spec, "maximum", high) + return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name) + + @property + def minimum(self): + """Returns a NumPy array specifying the minimum bounds (inclusive).""" + return self._minimum + + @property + def maximum(self): + """Returns a NumPy array specifying the maximum bounds (inclusive).""" + return self._maximum + + def __repr__(self): + s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})" + return s.format(self.shape, repr(self.dtype), repr(self.name), + repr(self.minimum), repr(self.maximum)) + + def __eq__(self, other): + tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other) + return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and + np.allclose(self.maximum, other.maximum)) + + diff --git a/tensorflow/python/framework/tensor_spec_test.py b/tensorflow/python/framework/tensor_spec_test.py new file mode 100644 index 0000000000000000000000000000000000000000..54ca4d9a19c2e1c879c05cfb828085951bdd8444 --- /dev/null +++ b/tensorflow/python/framework/tensor_spec_test.py @@ -0,0 +1,227 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensor_spec.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class TensorSpecTest(test_util.TensorFlowTestCase): + + def testAcceptsNumpyDType(self): + desc = tensor_spec.TensorSpec([1], np.float32) + self.assertEqual(desc.dtype, dtypes.float32) + + def testAcceptsTensorShape(self): + desc = tensor_spec.TensorSpec(tensor_shape.TensorShape([1]), dtypes.float32) + self.assertEqual(desc.shape, tensor_shape.TensorShape([1])) + + def testUnknownShape(self): + desc = tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32) + self.assertEqual(desc.shape, tensor_shape.TensorShape(None)) + + def testShapeCompatibility(self): + unknown = array_ops.placeholder(dtypes.int64) + partial = array_ops.placeholder(dtypes.int64, shape=[None, 1]) + full = array_ops.placeholder(dtypes.int64, shape=[2, 3]) + rank3 = array_ops.placeholder(dtypes.int64, shape=[4, 5, 6]) + + desc_unknown = tensor_spec.TensorSpec(None, dtypes.int64) + self.assertTrue(desc_unknown.is_compatible_with(unknown)) + self.assertTrue(desc_unknown.is_compatible_with(partial)) + self.assertTrue(desc_unknown.is_compatible_with(full)) + self.assertTrue(desc_unknown.is_compatible_with(rank3)) + + desc_partial = tensor_spec.TensorSpec([2, None], dtypes.int64) + self.assertTrue(desc_partial.is_compatible_with(unknown)) + self.assertTrue(desc_partial.is_compatible_with(partial)) + self.assertTrue(desc_partial.is_compatible_with(full)) + self.assertFalse(desc_partial.is_compatible_with(rank3)) + + desc_full = tensor_spec.TensorSpec([2, 3], dtypes.int64) + self.assertTrue(desc_full.is_compatible_with(unknown)) + self.assertFalse(desc_full.is_compatible_with(partial)) + self.assertTrue(desc_full.is_compatible_with(full)) + self.assertFalse(desc_full.is_compatible_with(rank3)) + + desc_rank3 = tensor_spec.TensorSpec([4, 5, 6], dtypes.int64) + self.assertTrue(desc_rank3.is_compatible_with(unknown)) + self.assertFalse(desc_rank3.is_compatible_with(partial)) + self.assertFalse(desc_rank3.is_compatible_with(full)) + self.assertTrue(desc_rank3.is_compatible_with(rank3)) + + def testTypeCompatibility(self): + floats = array_ops.placeholder(dtypes.float32, shape=[10, 10]) + ints = array_ops.placeholder(dtypes.int32, shape=[10, 10]) + desc = tensor_spec.TensorSpec(shape=(10, 10), dtype=dtypes.float32) + self.assertTrue(desc.is_compatible_with(floats)) + self.assertFalse(desc.is_compatible_with(ints)) + + def testName(self): + desc = tensor_spec.TensorSpec([1], dtypes.float32, name="beep") + self.assertEqual(desc.name, "beep") + + def testRepr(self): + desc1 = tensor_spec.TensorSpec([1], dtypes.float32, name="beep") + self.assertEqual( + repr(desc1), + "TensorSpec(shape=(1,), dtype=tf.float32, name='beep')") + desc2 = tensor_spec.TensorSpec([1, None], dtypes.int32) + self.assertEqual( + repr(desc2), + "TensorSpec(shape=(1, ?), dtype=tf.int32, name=None)") + + def testFromTensorSpec(self): + spec_1 = tensor_spec.TensorSpec((1, 2), dtypes.int32) + spec_2 = tensor_spec.TensorSpec.from_spec(spec_1) + self.assertEqual(spec_1, spec_2) + + def testFromTensor(self): + zero = constant_op.constant(0) + spec = tensor_spec.TensorSpec.from_tensor(zero) + self.assertEqual(spec.dtype, dtypes.int32) + self.assertEqual(spec.shape, []) + self.assertEqual(spec.name, "Const") + + def testFromPlaceholder(self): + unknown = array_ops.placeholder(dtypes.int64, name="unknown") + partial = array_ops.placeholder(dtypes.float32, + shape=[None, 1], + name="partial") + spec_1 = tensor_spec.TensorSpec.from_tensor(unknown) + self.assertEqual(spec_1.dtype, dtypes.int64) + self.assertEqual(spec_1.shape, None) + self.assertEqual(spec_1.name, "unknown") + spec_2 = tensor_spec.TensorSpec.from_tensor(partial) + self.assertEqual(spec_2.dtype, dtypes.float32) + self.assertEqual(spec_2.shape.as_list(), [None, 1]) + self.assertEqual(spec_2.name, "partial") + + def testFromBoundedTensorSpec(self): + bounded_spec = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, 0, 1) + spec = tensor_spec.TensorSpec.from_spec(bounded_spec) + self.assertEqual(bounded_spec.shape, spec.shape) + self.assertEqual(bounded_spec.dtype, spec.dtype) + self.assertEqual(bounded_spec.name, spec.name) + + +class BoundedTensorSpecTest(test_util.TensorFlowTestCase): + + def testInvalidMinimum(self): + with self.assertRaisesRegexp(ValueError, "not compatible"): + tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, (0, 0, 0), (1, 1)) + + def testInvalidMaximum(self): + with self.assertRaisesRegexp(ValueError, "not compatible"): + tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1)) + + def testMinimumMaximumAttributes(self): + spec = tensor_spec.BoundedTensorSpec( + (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) + self.assertEqual(type(spec.minimum), np.ndarray) + self.assertEqual(type(spec.maximum), np.ndarray) + self.assertAllEqual(spec.minimum, np.array(0, dtype=np.float32)) + self.assertAllEqual(spec.maximum, np.array([5, 5, 5], dtype=np.float32)) + + def testNotWriteableNP(self): + spec = tensor_spec.BoundedTensorSpec( + (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) + with self.assertRaisesRegexp(ValueError, "read-only"): + spec.minimum[0] = -1 + with self.assertRaisesRegexp(ValueError, "read-only"): + spec.maximum[0] = 100 + + def testReuseSpec(self): + spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, + minimum=0, maximum=1) + spec_2 = tensor_spec.BoundedTensorSpec( + spec_1.shape, spec_1.dtype, spec_1.minimum, spec_1.maximum) + self.assertEqual(spec_1, spec_2) + + def testScalarBounds(self): + spec = tensor_spec.BoundedTensorSpec( + (), dtypes.float32, minimum=0.0, maximum=1.0) + + self.assertIsInstance(spec.minimum, np.ndarray) + self.assertIsInstance(spec.maximum, np.ndarray) + + # Sanity check that numpy compares correctly to a scalar for an empty shape. + self.assertEqual(0.0, spec.minimum) + self.assertEqual(1.0, spec.maximum) + + # Check that the spec doesn't fail its own input validation. + _ = tensor_spec.BoundedTensorSpec( + spec.shape, spec.dtype, spec.minimum, spec.maximum) + + def testFromBoundedTensorSpec(self): + spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, + minimum=0, maximum=1) + spec_2 = tensor_spec.BoundedTensorSpec.from_spec(spec_1) + self.assertEqual(spec_1, spec_2) + + def testEquality(self): + spec_1_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 0, (5, 5, 5)) + spec_1_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 0.00000001, + (5, 5, 5.00000000000000001)) + spec_2_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 1, (5, 5, 5)) + spec_2_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (1, 1, 1), (5, 5, 5)) + spec_2_3 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (1, 1, 1), 5) + spec_3_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (2, 1, 1), (5, 5, 5)) + + self.assertEqual(spec_1_1, spec_1_2) + self.assertEqual(spec_1_2, spec_1_1) + + self.assertNotEqual(spec_1_1, spec_2_2) + self.assertNotEqual(spec_1_1, spec_2_1) + self.assertNotEqual(spec_2_2, spec_1_1) + self.assertNotEqual(spec_2_1, spec_1_1) + + self.assertEqual(spec_2_1, spec_2_2) + self.assertEqual(spec_2_2, spec_2_1) + self.assertEqual(spec_2_2, spec_2_3) + + self.assertNotEqual(spec_1_1, spec_3_1) + self.assertNotEqual(spec_2_1, spec_3_1) + self.assertNotEqual(spec_2_2, spec_3_1) + + def testFromTensorSpec(self): + spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + bounded_spec = tensor_spec.BoundedTensorSpec.from_spec(spec) + self.assertEqual(spec.shape, bounded_spec.shape) + self.assertEqual(spec.dtype, bounded_spec.dtype) + self.assertEqual(spec.dtype.min, bounded_spec.minimum) + self.assertEqual(spec.dtype.max, bounded_spec.maximum) + self.assertEqual(spec.name, bounded_spec.name) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4a8aa2e258cafe9f406f734937955cb64366c929..15e8f5a38d6a1cbfa46c305f9f5c24e9d2dbc1d7 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -320,12 +320,17 @@ def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs): prev_value = ops._USE_C_API ops._USE_C_API = use_c_api try: - with ops.Graph().as_default(): - fn(*args, **kwargs) + # Reset the default graph so it has the C API enabled. We call + # reset_default_graph() instead of creating a new default Graph context to + # make this robust to tests that call reset_default_graph(), which requires + # that the current default graph isn't nested. + ops.reset_default_graph() + fn(*args, **kwargs) finally: ops._USE_C_API = prev_value - - + # Make sure default graph reflects prev_value in case next test doesn't call + # reset_default_graph(). + ops.reset_default_graph() # pylint: disable=protected-access @@ -420,66 +425,6 @@ def with_c_api(cls): return cls -class IsolateTest(object): - """A context manager which isolates resources in its block. - - Provides an Eager-agnostic abstraction for preventing the sharing of - variables and other resources. - - In graph mode, resource handle ops are only executed in a particular Session, - isolating them from resources with the same name in other Graphs. In Eager, - separate Sessions do not exist, so resources (particularly ResourceVariables) - would be shared implicitly if a resource of the same name were created - anywhere in a Python process. Multiple handles to the same resource would - cause several issues, and so this type of sharing will raise an exception. - - Using resources with the same name in a single Python process may be useful - (especially for unit tests), so this context manager provides an abstraction - for isolating resources. Using a resource created in one Isolation environment - in another is an error. - - Example usage in Eager mode: - - ```python - import tensorflow as tf - # Import subject to change - from tensorflow.contrib.eager.python import tfe - - tfe.enable_eager_execution() - - for hyperparameter in [1, 2, 3]: - with tfe.IsolateTest(): - v = tfe.Variable(name="v", initial_value=hyperparameter) - # train model, test results ... - ``` - - IsolateTest is currently exposed through contrib.eager, but it creates a new - default Graph and provides equivalent safety in graph mode. - """ - - def __init__(self): - if context.in_eager_mode() and tape.could_possibly_record(): - raise ValueError("Cannot isolate Eager execution with an active tape.") - # In Eager, Graphs set a container which isolates resources, and maintain a - # VariableStore which caches ResourceVariable objects created through - # get_variable. So setting the default Graph has the side effect of - # isolating Eager resources. - with context.eager_mode(): - # Create the graph in Eager mode, as this provides stricter semantics - # (i.e. has a unique container prefix). This prevents implicit sharing - # when a Graph-mode graph is created and then Eager mode is enabled (an - # error through enable_eager_execution, but common with context managers - # in unit tests). - self._graph_as_default_context_manager = ops.Graph().as_default() - - def __enter__(self): - self._graph_as_default_context_manager.__enter__() - - def __exit__(self, type_arg, value_arg, traceback_arg): - return self._graph_as_default_context_manager.__exit__( - type_arg, value_arg, traceback_arg) - - def assert_no_new_tensors(f): """Decorator for asserting that no new Tensors persist after a test. @@ -510,12 +455,11 @@ def assert_no_new_tensors(f): return False tensors_before = set(id(obj) for obj in gc.get_objects() if _is_tensor(obj)) - outside_container_prefix = ops.get_default_graph()._container_prefix - with IsolateTest(): + outside_graph_key = ops.get_default_graph()._graph_key + with ops.Graph().as_default(): # Run the test in a new graph so that collections get cleared when it's - # done, but inherit the container prefix so that we can print the values - # of variables which get leaked when executing eagerly. - ops.get_default_graph()._container_prefix = outside_container_prefix + # done, but inherit the graph key so optimizers behave. + ops.get_default_graph()._graph_key = outside_graph_key f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. @@ -637,7 +581,7 @@ def run_in_graph_and_eager_modes(__unused__=None, assert_no_garbage_created(run_eager_mode)) with context.eager_mode(): - with IsolateTest(): + with ops.Graph().as_default(): run_eager_mode(self, **kwargs) return decorated diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 3594d125bf616917727bea4958eaabf159d0aee0..a717eb39513ac3369ae133b6090ff82597f12eb7 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -29,7 +29,6 @@ from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors @@ -39,7 +38,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -443,71 +441,5 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): LeakedTensorTest().test_has_no_leak() -@test_util.with_c_api -class IsolationTest(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes() - def test_variable_reuse_exception(self): - with test_util.IsolateTest(), session.Session(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="first_container_variable", - initial_value=1) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - with test_util.IsolateTest(): - if context.in_graph_mode(): - with self.assertRaises(RuntimeError): - self.evaluate(first_container_variable.read_value()) - else: - with self.assertRaises(ValueError): - first_container_variable.read_value() - - @test_util.run_in_graph_and_eager_modes() - def test_variable_reuse_exception_nested(self): - with test_util.IsolateTest(), session.Session(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="first_container_variable", - initial_value=1) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - with test_util.IsolateTest(), session.Session(): - if context.in_graph_mode(): - with self.assertRaises(RuntimeError): - self.evaluate(first_container_variable.read_value()) - else: - with self.assertRaises(ValueError): - first_container_variable.read_value() - - @test_util.run_in_graph_and_eager_modes() - def test_no_sharing(self): - with test_util.IsolateTest(), session.Session(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="same_name", - initial_value=1) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - with test_util.IsolateTest(), session.Session(): - second_container_variable = resource_variable_ops.ResourceVariable( - name="same_name", - initial_value=2) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - self.assertEqual( - 2, self.evaluate(second_container_variable.read_value())) - self.assertEqual(1, self.evaluate(first_container_variable.read_value())) - - def test_graph_mode_isolation(self): - with context.graph_mode(): - # Even if we've (accidentally) called IsolateTest in Graph mode, it should - # provide Eager isolation. - with test_util.IsolateTest(): - with context.eager_mode(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="first_container_variable", - initial_value=1) - with context.eager_mode(): - with self.assertRaises(ValueError): - first_container_variable.read_value() - if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 2292b2c732b2d5d0d40b44d8ca831f4e72b057c6..10d515a36458d4025060cf4900251cd493f40795 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -45,7 +45,7 @@ class ClusterTest(test.TestCase): op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts( grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 9) + self.assertEqual(len(op_perfs), 7) self.assertTrue(step_stats.dev_stats) def testNoDetailedStats(self): @@ -125,7 +125,7 @@ class ClusterTest(test.TestCase): disable_detailed_stats=False, disable_timeline=False) as gcluster: op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 9) + self.assertEqual(len(op_perfs), 7) self.assertTrue(step_stats.dev_stats) def testAvailableOps(self): diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 61dc4e2afb833414f875d66bb12b0aa010f9d62e..ac251f2bbd9a30d2da9303bcbd20e84e5f5fc772 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -22,7 +22,7 @@ import argparse import sys from google.protobuf import text_format - +from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -43,7 +43,10 @@ def main(_): else: with gfile.GFile(FLAGS.graphdef) as graph_file: graph_def = graph_pb2.GraphDef() - graph_def.ParseFromString(graph_file.read()) + if FLAGS.graphdef.endswith(".pbtxt"): + text_format.Merge(graph_file.read(), graph_def) + else: + graph_def.ParseFromString(graph_file.read()) importer.import_graph_def(graph_def, name="") graph = ops.get_default_graph() fetch = graph.get_operation_by_name(FLAGS.fetch) diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 578f86ca5a0c1f2446dbf26ce412e34f3bdbd23a..5bc9e4b8030d69530a2427172badb7331ad58155 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -157,6 +157,7 @@ def _get_config(layout_optimizer=True): graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) + config.graph_options.optimizer_options.opt_level = -1 return config @@ -179,6 +180,8 @@ def _get_cluster(): named_device = device_properties_pb2.NamedDevice() named_device.name = '/GPU:0' named_device.properties.type = 'GPU' + named_device.properties.num_cores = 24 + named_device.properties.frequency = 1000 named_device.properties.environment['architecture'] = '4' cluster = gcluster.Cluster(devices=[named_device]) return cluster @@ -1169,7 +1172,7 @@ class LayoutOptimizerTest(test.TestCase): num_transposes += 1 nodes.append(node.name) - expected_num_transposes = 2 + expected_num_transposes = 3 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 61257557751359a45e3ef9f74ee6307b4c6d21dc..fdac22bb53cc7e78d854d4b5ff756a190c9c62b6 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -39,6 +39,7 @@ py_library( "_impl/keras/engine/__init__.py", "_impl/keras/engine/topology.py", "_impl/keras/engine/training.py", + "_impl/keras/engine/training_eager.py", "_impl/keras/estimator.py", "_impl/keras/initializers.py", "_impl/keras/layers/__init__.py", @@ -481,6 +482,7 @@ py_test( size = "small", srcs = ["_impl/keras/layers/normalization_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":keras", "//tensorflow/python:client_testlib", @@ -719,6 +721,19 @@ py_test( ], ) +py_test( + name = "training_eager_test", + size = "medium", + srcs = ["_impl/keras/engine/training_eager_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "topology_test", size = "small", diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index 460c0dc5f39baac7b171568e6014c22eac23ccfc..098ea063f951ef86c2a474aa50d6239b514cc699 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -29,6 +29,7 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_module +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops @@ -326,7 +327,15 @@ def learning_phase(): Returns: Learning phase (scalar integer tensor or Python integer). + + Raises: + ValueError: If called when Eager execution is enabled. """ + if context.in_eager_mode(): + if 'eager' not in _GRAPH_LEARNING_PHASES: + raise ValueError('No learning phase set in Eager mode.') + return _GRAPH_LEARNING_PHASES['eager'] + graph = ops.get_default_graph() if graph not in _GRAPH_LEARNING_PHASES: phase = array_ops.placeholder_with_default( @@ -347,7 +356,10 @@ def set_learning_phase(value): global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned if value not in {0, 1}: raise ValueError('Expected learning phase to be ' '0 or 1.') - _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value + if context.in_eager_mode(): + _GRAPH_LEARNING_PHASES['eager'] = value + else: + _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value def get_session(): diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py index 64aa868f3822c4dfcfbe8ae1764d617a00ffff4d..8354a2b8fd7c0182c1daaa7e8fa8390da6038d0b 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology.py @@ -708,8 +708,10 @@ class Network(tf_network.GraphNetwork, Layer): self.input_names.append(layer.name) if layer.is_placeholder: self._feed_input_names.append(layer.name) - self._feed_inputs.append(layer.input) self._feed_input_shapes.append(K.int_shape(self.inputs[i])) + # layer.input gives an error in eager mode + if context.in_graph_mode(): + self._feed_inputs.append(layer.input) for layer in self._output_layers: self.output_names.append(layer.name) diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 699ae2edf0db1cdcc73763607d04329c76888565..43d95b1f194e1d30cbdae726fdf5979bd7065d25 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -22,17 +22,21 @@ import copy import numpy as np +from tensorflow.python.eager import context +from tensorflow.python.framework import ops from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module from tensorflow.python.keras._impl.keras import optimizers +from tensorflow.python.keras._impl.keras.engine import training_eager from tensorflow.python.keras._impl.keras.engine.topology import Network from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import optimizer as tf_optimizer_module try: from scipy.sparse import issparse # pylint: disable=g-import-not-at-top @@ -82,21 +86,24 @@ def _standardize_input_data(data, if data[x].__class__.__name__ == 'DataFrame' else data[x] for x in names ] - data = [np.expand_dims(x, 1) if x.ndim == 1 else x for x in data] except KeyError as e: raise ValueError('No data provided for "' + e.args[0] + '". Need data ' 'for each key in: ' + str(names)) elif isinstance(data, list): - data = [ - x.values if x.__class__.__name__ == 'DataFrame' else x for x in data - ] - data = [ - np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x - for x in data - ] + if isinstance(data[0], list): + data = [np.asarray(d) for d in data] + elif len(names) == 1 and isinstance(data[0], (float, int)): + data = [np.asarray(data)] + else: + data = [ + x.values if x.__class__.__name__ == 'DataFrame' else x for x in data + ] else: data = data.values if data.__class__.__name__ == 'DataFrame' else data - data = [np.expand_dims(data, 1)] if data.ndim == 1 else [data] + data = [data] + data = [ + np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data + ] if len(data) != len(names): if data and hasattr(data[0], 'shape'): @@ -618,9 +625,15 @@ class Model(Network): `optimizer`, `loss`, `metrics` or `sample_weight_mode`. """ loss = loss or {} + if context.in_eager_mode() and not isinstance( + optimizer, tf_optimizer_module.Optimizer): + raise ValueError('Only TF native optimizers are supported in Eager mode.') + self.optimizer = optimizers.get(optimizer) self.loss = loss self.loss_weights = loss_weights + if context.in_eager_mode() and sample_weight_mode is not None: + raise ValueError('sample_weight_mode is not supported in Eager mode.') self.sample_weight_mode = sample_weight_mode # Prepare loss functions. @@ -651,6 +664,7 @@ class Model(Network): loss_function = losses.get(loss) loss_functions = [loss_function for _ in range(len(self.outputs))] self.loss_functions = loss_functions + weighted_losses = [_weighted_masked_objective(fn) for fn in loss_functions] skip_target_indices = [] skip_target_weighing_indices = [] @@ -664,11 +678,12 @@ class Model(Network): skip_target_weighing_indices.append(i) # Prepare output masks. - masks = self.compute_mask(self.inputs, mask=None) - if masks is None: - masks = [None for _ in self.outputs] - if not isinstance(masks, list): - masks = [masks] + if context.in_graph_mode(): + masks = self.compute_mask(self.inputs, mask=None) + if masks is None: + masks = [None for _ in self.outputs] + if not isinstance(masks, list): + masks = [masks] # Prepare loss weights. if loss_weights is None: @@ -694,6 +709,32 @@ class Model(Network): else: raise TypeError('Could not interpret loss_weights argument: ' + str(loss_weights) + ' - expected a list of dicts.') + self.loss_weights_list = loss_weights_list + + # initialization for Eager mode execution + if context.in_eager_mode(): + if target_tensors is not None: + raise ValueError('target_tensors are not currently supported in Eager' + 'mode.') + self.total_loss = None + self.metrics = metrics + self.weighted_metrics = weighted_metrics + self.metrics_tensors = [] + self.metrics_names = ['loss'] + for i in range(len(self.outputs)): + if len(self.outputs) > 1: + self.metrics_names.append(self.output_names[i] + '_loss') + self.nested_metrics = _collect_metrics(metrics, self.output_names) + self._feed_sample_weight_modes = [] + for i in range(len(self.outputs)): + self._feed_sample_weight_modes.append(None) + self.sample_weights = [] + self.targets = [] + self._collected_trainable_weights = self.trainable_weights + for i in range(len(self.outputs)): + self._feed_output_names.append(self.output_names[i]) + + return # Prepare targets of model. self.targets = [] @@ -720,6 +761,7 @@ class Model(Network): else: raise TypeError('Expected `target_tensors` to be ' 'a list or dict, but got:', target_tensors) + for i in range(len(self.outputs)): if i in skip_target_indices: self.targets.append(None) @@ -769,7 +811,7 @@ class Model(Network): weight = K.placeholder(ndim=2, name=name + '_sample_weights') sample_weight_modes.append('temporal') else: - weight = K.placeholder(ndim=1, name=name + '_sample_weights') + weight = K.placeholder(ndim=1, name=name + 'sample_weights') sample_weight_modes.append(None) sample_weights.append(weight) elif isinstance(sample_weight_mode, list): @@ -929,7 +971,7 @@ class Model(Network): self._feed_sample_weights = [] for i in range(len(self.sample_weights)): if i not in skip_target_weighing_indices: - self._feed_sample_weights.append(sample_weights[i]) + self._feed_sample_weights.append(self.sample_weights[i]) # Functions for train, test and predict will # be compiled lazily when required. @@ -978,6 +1020,7 @@ class Model(Network): with K.name_scope(self.optimizer.__class__.__name__): training_updates = self.optimizer.get_updates( params=self._collected_trainable_weights, loss=self.total_loss) + updates = self.updates + training_updates # Gets loss and metrics. Updates weights at each call. self.train_function = K.function( @@ -1156,6 +1199,7 @@ class Model(Network): callback_model = self callbacks.set_model(callback_model) + callbacks.set_params({ 'batch_size': batch_size, 'epochs': epochs, @@ -1216,6 +1260,7 @@ class Model(Network): np.random.shuffle(index_array) batches = _make_batches(num_train_samples, batch_size) + for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] try: @@ -1410,6 +1455,7 @@ class Model(Network): ins_batch[i] = ins_batch[i].toarray() batch_outs = f(ins_batch) + if isinstance(batch_outs, list): if batch_index == 0: for batch_out in enumerate(batch_outs): @@ -1420,7 +1466,6 @@ class Model(Network): if batch_index == 0: outs.append(0.) outs[0] += batch_outs * len(batch_ids) - if verbose == 1: progbar.update(batch_end) for i in range(len(outs)): @@ -1636,6 +1681,7 @@ class Model(Network): batch_size=batch_size) # Prepare validation data. do_validation = False + val_ins = [] if validation_data: do_validation = True if len(validation_data) == 2: @@ -1686,39 +1732,65 @@ class Model(Network): ins = x + y + sample_weights + [1.] else: ins = x + y + sample_weights - self._make_train_function() - f = self.train_function # Prepare display labels. out_labels = self._get_deduped_metrics_names() - if do_validation: - self._make_test_function() - val_f = self.test_function - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] + if context.in_eager_mode(): + if do_validation: + callback_metrics = copy.copy(out_labels) + [ + 'val_' + n for n in out_labels + ] + else: + callback_metrics = copy.copy(out_labels) + + return training_eager.fit_loop( + self, + ins, + out_labels=out_labels, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + val_ins=val_ins, + shuffle=shuffle, + callback_metrics=callback_metrics, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) else: - callback_metrics = copy.copy(out_labels) - val_f = None - val_ins = [] - - # Delegate logic to `_fit_loop`. - return self._fit_loop( - f, - ins, - out_labels=out_labels, - batch_size=batch_size, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - val_f=val_f, - val_ins=val_ins, - shuffle=shuffle, - callback_metrics=callback_metrics, - initial_epoch=initial_epoch, - steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps) + self._make_train_function() + f = self.train_function + + if do_validation: + if context.in_graph_mode(): + self._make_test_function() + val_f = self.test_function + else: + val_f = None + callback_metrics = copy.copy(out_labels) + [ + 'val_' + n for n in out_labels + ] + else: + val_f = None + callback_metrics = copy.copy(out_labels) + + # Delegate logic to `_fit_loop`. + return self._fit_loop( + f, + ins, + out_labels=out_labels, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + val_f=val_f, + val_ins=val_ins, + shuffle=shuffle, + callback_metrics=callback_metrics, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) def evaluate(self, x=None, @@ -1794,10 +1866,15 @@ class Model(Network): ins = x + y + sample_weights + [0.] else: ins = x + y + sample_weights - self._make_test_function() - f = self.test_function - return self._test_loop( - f, ins, batch_size=batch_size, verbose=verbose, steps=steps) + + if context.in_eager_mode(): + return training_eager.test_loop( + self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + else: + self._make_test_function() + f = self.test_function + return self._test_loop( + f, ins, batch_size=batch_size, verbose=verbose, steps=steps) def predict(self, x, batch_size=None, verbose=0, steps=None): """Generates output predictions for the input samples. @@ -1849,10 +1926,16 @@ class Model(Network): ins = x + [0.] else: ins = x - self._make_predict_function() - f = self.predict_function - return self._predict_loop( - f, ins, batch_size=batch_size, verbose=verbose, steps=steps) + + if context.in_eager_mode(): + return training_eager.predict_loop( + self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + else: + self._make_predict_function() + f = self.predict_function + + return self._predict_loop( + f, ins, batch_size=batch_size, verbose=verbose, steps=steps) def train_on_batch(self, x, y, sample_weight=None, class_weight=None): """Runs a single gradient update on a single batch of data. @@ -1888,6 +1971,7 @@ class Model(Network): or list of scalars (if the model has multiple outputs and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. + """ x, y, sample_weights = self._standardize_user_data( x, @@ -1899,11 +1983,16 @@ class Model(Network): ins = x + y + sample_weights + [1.] else: ins = x + y + sample_weights - self._make_train_function() - outputs = self.train_function(ins) - if len(outputs) == 1: - return outputs[0] - return outputs + + if context.in_eager_mode(): + return training_eager.train_on_batch(self, ins) + + if context.in_graph_mode(): + self._make_train_function() + outputs = self.train_function(ins) + if len(outputs) == 1: + return outputs[0] + return outputs def test_on_batch(self, x, y, sample_weight=None): """Test the model on a single batch of samples. @@ -1942,11 +2031,16 @@ class Model(Network): ins = x + y + sample_weights + [0.] else: ins = x + y + sample_weights - self._make_test_function() - outputs = self.test_function(ins) - if len(outputs) == 1: - return outputs[0] - return outputs + + if context.in_eager_mode(): + return training_eager.test_on_batch(self, ins) + + if context.in_graph_mode(): + self._make_test_function() + outputs = self.test_function(ins) + if len(outputs) == 1: + return outputs[0] + return outputs def predict_on_batch(self, x): """Returns predictions for a single batch of samples. @@ -1956,6 +2050,7 @@ class Model(Network): Returns: Numpy array(s) of predictions. + """ x = _standardize_input_data(x, self._feed_input_names, self._feed_input_shapes) @@ -1963,11 +2058,25 @@ class Model(Network): ins = x + [0.] else: ins = x - self._make_predict_function() - outputs = self.predict_function(ins) - if len(outputs) == 1: - return outputs[0] - return outputs + + if context.in_eager_mode(): + ins_batch_converted = [] + for ib in ins: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + + eager_model_inputs = [] + for i in range(len(self.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + outs = self(eager_model_inputs) # pylint: disable=not-callable + return outs + + if context.in_graph_mode(): + self._make_predict_function() + outputs = self.predict_function(ins) + if len(outputs) == 1: + return outputs[0] + return outputs def fit_generator(self, generator, @@ -2072,7 +2181,6 @@ class Model(Network): model.fit_generator(generate_arrays_from_file('/my_file.txt'), steps_per_epoch=10000, epochs=10) ``` - Raises: ValueError: In case the generator yields data in an invalid format. diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py new file mode 100644 index 0000000000000000000000000000000000000000..0a115969ca614d8d50a60f8980fa49bf404cc66f --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py @@ -0,0 +1,666 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras training and evaluation routines. +""" +# pylint: disable=protected-access +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensorflow.python.eager.backprop import GradientTape +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import callbacks as cbks +from tensorflow.python.keras._impl.keras import losses +from tensorflow.python.keras._impl.keras import metrics as metrics_module +from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar + + +def _make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Arguments: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. + """ + num_batches = int(np.ceil(size / float(batch_size))) + return [(i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches)] + + +def _slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `_slice_arrays(x, indices)` + + Arguments: + arrays: Single array or list of arrays. + start: can be an integer index (start index) + or a list/array of indices + stop: integer (stop index); should be None if + `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError('The stop argument has to be None if the value of start is' + 'a list.') + elif isinstance(arrays, list): + if hasattr(start, '__len__'): + # hdf5 datasets only support list objects as indices + if hasattr(start, 'shape'): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + else: + return [None if x is None else x[start:stop] for x in arrays] + else: + if hasattr(start, '__len__'): + if hasattr(start, 'shape'): + start = start.tolist() + return arrays[start] + elif hasattr(start, '__getitem__'): + return arrays[start:stop] + else: + return [None] + + +def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None): + if metric == 'accuracy' or metric == 'acc': + # custom handling of accuracy + # (because of class mode duality) + output_shape = internal_output_shapes + if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy: + # case: binary accuracy + acc_fn = metrics_module.binary_accuracy + elif loss_func == losses.sparse_categorical_crossentropy: + # case: categorical accuracy with sparse targets + acc_fn = metrics_module.sparse_categorical_accuracy + else: + acc_fn = metrics_module.categorical_accuracy + + metric_name = 'acc' + return metric_name, acc_fn + else: + metric_fn = metrics_module.get(metric) + metric_name = metric_fn.__name__ + return metric_name, metric_fn + + +def _eager_loss_fn(outputs, targets, loss_fn, output_name): + with K.name_scope(output_name + '_loss'): + loss = loss_fn(targets, outputs) + return loss + + +def _eager_metrics_fn(model, outputs, targets): + """Calculates the metrics for each output of the given model. + + Arguments: + model: The model on which metrics are being calculated. + outputs: The outputs of the given model. + targets: The predictions or targets of the given model. + + Returns: + Returns the metric names and metric results for each output of the model. + """ + metric_names = [] + metric_results = [] + if not isinstance(outputs, list): + outputs = [outputs] + + if not isinstance(targets, list): + targets = [targets] + + for i in range(len(model.outputs)): + output_metrics = model.nested_metrics[i] + for nested_output_metric in output_metrics: + metric_name, metric_fn = _get_metrics_info( + nested_output_metric, model._internal_output_shapes[i], + model.loss_functions[i]) + + if len(model.output_names) > 1: + metric_name = model.output_names[i] + '_' + metric_name + if metric_name not in model.metrics_names: + model.metrics_names.append(metric_name) + + with K.name_scope(metric_name): + metric_result = metric_fn(outputs[i], targets[i]) + metric_names.append(metric_name) + metric_results.append(K.mean(metric_result)) + + return metric_names, metric_results + + +def _model_loss(model, inputs, targets): + """Calculates the loss for a given model. + + Arguments: + model: The model on which metrics are being calculated. + inputs: The inputs of the given model. This is typically the mini batch of + data that is fed to the model. + targets: The predictions or targets of the given model. + + Returns: + Returns the model output, total loss and loss value calculated using the + specified loss function. The total loss includes regularization losses and + applies masking and sample weighting to the loss value. + """ + total_loss = 0 + outs = model(inputs) + if not isinstance(outs, list): + outs = [outs] + + if not isinstance(targets, list): + targets = [targets] + + loss_metrics = [] + with K.name_scope('loss'): + for i, loss_fn in enumerate(model.loss_functions): + # compute the loss + output_loss = _eager_loss_fn(outs[i], targets[i], loss_fn, + model.output_names[i]) + loss_metrics.append(K.mean(output_loss)) + + mask = outs[i]._keras_mask + # adapted from weighted_loss_fn + if mask is not None: + # mask should have the same shape as output_loss + output_loss *= mask + # the loss per batch should be proportional + # to the number of unmasked samples. + output_loss /= K.mean(mask) + + # adapted from weighted_loss_fn + # apply sample weighting + if model.sample_weights: + # reduce score_array to same ndim as weight array + ndim = K.ndim(output_loss) + weight_ndim = K.ndim(model.sample_weights) + output_loss = K.mean(output_loss, axis=list(range(weight_ndim, ndim))) + output_loss *= model.sample_weights + output_loss /= K.mean(K.cast(K.not_equal(model.sample_weights, 0), + K.floatx())) + output_loss = K.mean(output_loss) + + loss_weight = model.loss_weights_list[i] + if total_loss is None: + total_loss = loss_weight * output_loss + else: + total_loss += loss_weight * output_loss + + total_loss = K.mean(total_loss) + # Add regularization losses + custom_losses = [] + for layer in model.layers: + if layer.losses: + custom_losses += layer.losses + + if custom_losses: + total_loss += sum(custom_losses) + + return outs, total_loss, loss_metrics + + +def _process_single_batch(eager_model_inputs, eager_model_outputs, model, + training=True): + """Calculate the loss and gradient for one input batch. + + The model weights are updated if training is set to True. + + Arguments: + eager_model_inputs: Input batch data. + eager_model_outputs: Output batch data. + model: Model whose loss has to be calculated. + training: The boolean represents if the weights of the model are updated. + 'fit' methods will set this to True while 'evaluate' methods will + set this to False. + + Returns: + output of the model, total loss and the loss associated with each output. + + Raises: + ValueError: If the model loss is 0 or if the trainable weights list is + empty when the trainable parameter is set to True. + """ + K.set_learning_phase(training) + with GradientTape() as tape: + outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, + eager_model_outputs) + if loss is None: + raise ValueError('The model cannot be run ' + 'because it has no loss to optimize.') + if training: + if not model._collected_trainable_weights: + raise ValueError('The list of trainable weights is empty. Make sure that ' + 'you are not setting model.trainable to False before ' + 'compiling the model.') + grads = tape.gradient(loss, model._collected_trainable_weights) + model.optimizer.apply_gradients(zip(grads, + model._collected_trainable_weights)) + return outs, loss, loss_metrics + + +def train_on_batch(model, ins): + """Calculates the loss and gradient updates for one input batch. + + Arguments: + model: Given model on which loss and gradients are calculated. + ins: Input and output batch numpy arrays. + + Returns: + total loss and the loss associated with each output. + """ + ins_batch_converted = [] + for ib in ins: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + outs, loss, _ = _process_single_batch( + eager_model_inputs, eager_model_outputs, model) + if not isinstance(outs, list): + outs = [outs] + _, metrics_results = _eager_metrics_fn( + model, outs, eager_model_outputs) + if not isinstance(loss, list): + loss = [loss] + return loss + metrics_results + + +def test_on_batch(model, ins): + """Calculates the loss for one input batch. + + Arguments: + model: Given model on which loss is calculated. + ins: Input and output batch numpy arrays. + + Returns: + total loss, loss and metrics associated with each output. + """ + ins_batch_converted = [] + for ib in ins: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + outs, loss, loss_metrics = _process_single_batch( + eager_model_inputs, eager_model_outputs, model, training=False) + if not isinstance(outs, list): + outs = [outs] + metric_names, metrics_results = _eager_metrics_fn( + model, outs, eager_model_outputs) + model.metrics_names.append(metric_names) + if not isinstance(loss, list): + loss = [loss] + return loss + loss_metrics + metrics_results + + +def fit_loop( + model, + ins, + out_labels=None, + batch_size=None, + epochs=100, + verbose=1, + callbacks=None, + val_ins=None, + shuffle=True, + callback_metrics=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None): + """Abstract fit function for `f(ins)`. + + Assume that f returns a list, labeled by out_labels. + + Arguments: + model: Instance of the model that is being executed in Eager mode. + ins: List of tensors to be fed to `f` + out_labels: List of strings, display names of + the outputs of `f` + batch_size: Integer batch size or None if unknown. + epochs: Number of times to iterate over the data + verbose: Verbosity mode, 0, 1 or 2 + callbacks: List of callbacks to be called during training + val_ins: List of tensors to be fed to `val_f` + shuffle: Whether to shuffle the data at the beginning of each epoch + callback_metrics: List of strings, the display names of the metrics + passed to the callbacks. They should be the + concatenation of list the display names of the outputs of + `f` and the list of display names of the outputs of `f_val`. + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run) + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. Ignored with the default value of `None`. + validation_steps: Number of steps to run validation for (only if doing + validation from data tensors). Ignored with default value of `None`. + + Returns: + `History` object. + + Raises: + ValueError: In case of invalid argument values. + """ + # Required for Eager mode + K.set_learning_phase(True) + + do_validation = False + if val_ins: + do_validation = True + if (verbose and ins and hasattr(ins[0], 'shape') and + hasattr(val_ins[0], 'shape')): + print('Train on %d samples, validate on %d samples' % + (ins[0].shape[0], val_ins[0].shape[0])) + if validation_steps: + if steps_per_epoch is None: + raise ValueError('Can only use `validation_steps` when doing step-wise ' + 'training, i.e. `steps_per_epoch` must be set.') + do_validation = True + + num_train_samples = model._check_num_samples( + ins, batch_size, steps_per_epoch, 'steps_per_epoch') + + if num_train_samples is not None: + index_array = np.arange(num_train_samples) + + model.history = cbks.History() + callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] + if verbose: + if steps_per_epoch is not None: + count_mode = 'steps' + else: + count_mode = 'samples' + callbacks += [cbks.ProgbarLogger(count_mode)] + callbacks = cbks.CallbackList(callbacks) + out_labels = out_labels or [] + + # it's possible to callback a different model than self + # (used by Sequential models) + if hasattr(model, 'callback_model') and model.callback_model: + callback_model = model.callback_model + else: + callback_model = model + + callbacks.set_model(callback_model) + + callbacks.set_params({ + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': num_train_samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics or [], + }) + callbacks.on_train_begin() + callback_model.stop_training = False + for cbk in callbacks: + cbk.validation_data = val_ins + + for epoch in range(initial_epoch, epochs): + callbacks.on_epoch_begin(epoch) + epoch_logs = {} + if shuffle == 'batch': + index_array = model._batch_shuffle(index_array, batch_size) + elif shuffle: + np.random.shuffle(index_array) + + batches = _make_batches(num_train_samples, batch_size) + + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + try: + if isinstance(ins[-1], float): + # Do not slice the training phase flag. + ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = _slice_arrays(ins, batch_ids) + except TypeError: + raise TypeError('TypeError while preparing batch. ' + 'If using HDF5 input data, ' + 'pass shuffle="batch".') + batch_logs = {} + batch_logs['batch'] = batch_index + batch_logs['size'] = len(batch_ids) + + callbacks.on_batch_begin(batch_index, batch_logs) + + ins_batch_converted = [] + for ib in ins_batch: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + + outs, loss, loss_metrics = _process_single_batch(eager_model_inputs, + eager_model_outputs, + model) + + if not isinstance(outs, list): + outs = [outs] + + for l, o in zip(out_labels, outs): + batch_logs[l] = o + # Required for Eager mode + metrics_names, metrics_results = _eager_metrics_fn(model, outs, + eager_model_outputs) + batch_logs['loss'] = tensor_util.constant_value(K.mean(loss)) + + # TODO(anjalisridhar): Move this to compile to avoid duplicate code. + # In graph mode we set the metric names in compile. However in + # Eager mode we calculate the metrics for each batch in fit_loop. + # We could calculate the metric names and functions in compile. + # This would avoid setting the callback parameters separately. + # We need to do this for the first iteration alone + for m in metrics_names: + if m not in callback_metrics: + callback_metrics.append(m) + + callbacks.set_params({ + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': num_train_samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics or [], + }) + + for k, v in zip(model.metrics_names, + [K.mean(loss)] + loss_metrics + metrics_results): + batch_logs[k] = tensor_util.constant_value(v) + + callbacks.on_batch_end(batch_index, batch_logs) + if callback_model.stop_training: + break + + if batch_index == len(batches) - 1: # Last batch. + if do_validation: + val_outs = test_loop( + model, val_ins, batch_size=batch_size, verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + callbacks.on_epoch_end(epoch, epoch_logs) + if callback_model.stop_training: + break + callbacks.on_train_end() + return model.history + + +def test_loop(model, ins, batch_size=None, verbose=0, steps=None): + """Abstract method to loop over some data in batches. + + Arguments: + model: Model instance that is being evaluated in Eager mode. + ins: list of tensors to be fed to `f`. + batch_size: integer batch size or `None`. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring predictions finished. + Ignored with the default value of `None`. + + Returns: + Scalar loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + """ + K.set_learning_phase(False) + num_samples = model._check_num_samples(ins, batch_size, steps, 'steps') + outs = [] + if verbose == 1: + progbar = Progbar(target=num_samples) + batches = _make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + if isinstance(ins[-1], float): + # Do not slice the training phase flag. + ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = _slice_arrays(ins, batch_ids) + + ins_batch_converted = [] + for ib in ins_batch: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + + loss_outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, + eager_model_outputs) + _, metrics_results = _eager_metrics_fn(model, loss_outs, + eager_model_outputs) + batch_outs = [] + for _, v in zip(model.metrics_names, + [K.mean(loss)] + loss_metrics + metrics_results): + batch_outs.append(tensor_util.constant_value(v)) + + if isinstance(batch_outs, list): + if batch_index == 0: + for batch_out in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out * len(batch_ids) + else: + if batch_index == 0: + outs.append(0.) + outs[0] += batch_outs * len(batch_ids) + + if verbose == 1: + progbar.update(batch_end) + for i in range(len(outs)): + outs[i] /= num_samples + if len(outs) == 1: + return outs[0] + return outs + + +def predict_loop(model, ins, batch_size=32, verbose=0, steps=None): + """Abstract method to loop over some data in batches. + + Arguments: + model: + ins: list of tensors to be fed to `f`. + batch_size: integer batch size. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring `_predict_loop` finished. + Ignored with the default value of `None`. + + Returns: + Array of predictions (if the model has a single output) + or list of arrays of predictions + (if the model has multiple outputs). + """ + K.set_learning_phase(False) + num_samples = model._check_num_samples(ins, batch_size, steps, 'steps') + if verbose == 1: + if steps is not None: + progbar = Progbar(target=steps) + else: + progbar = Progbar(target=num_samples) + + outs = [] + batches = _make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + if ins and isinstance(ins[-1], float): + # Do not slice the training phase flag. + ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = _slice_arrays(ins, batch_ids) + + ins_batch_converted = [] + for ib in ins_batch: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + + eager_model_inputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + batch_outs = model(eager_model_inputs) + + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if batch_index == 0: + # Pre-allocate the results arrays. + for batch_out in batch_outs: + dims = batch_out.shape[1:].dims + dims_list = [d.value for d in dims] + shape = (num_samples,) + tuple(dims_list) + outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype)) + for i, batch_out in enumerate(batch_outs): + outs[i][batch_start:batch_end] = batch_out + if verbose == 1: + progbar.update(batch_end) + if len(outs) == 1: + return outs[0] + return outs diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py new file mode 100644 index 0000000000000000000000000000000000000000..81e2f7a5145a586f6a4cc34f54033723fae6a6e9 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py @@ -0,0 +1,755 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for training routines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.keras._impl import keras +from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer + + +class TrainingTest(test.TestCase): + + def test_fit_on_arrays(self): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + loss_weights = [1., 0.5] + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + # Test fit at different verbosity + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=0) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=1) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=2, + batch_size=5, + verbose=2) + + # Test with validation data + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + validation_data=([input_a_np, input_b_np], [output_d_np, + output_e_np]), + epochs=1, + batch_size=5, + verbose=0) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + validation_data=([input_a_np, input_b_np], [output_d_np, + output_e_np]), + epochs=2, + batch_size=5, + verbose=1) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + validation_data=([input_a_np, input_b_np], [output_d_np, + output_e_np]), + epochs=2, + batch_size=5, + verbose=2) + model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + + # Test with validation split + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=2, + batch_size=5, + verbose=0, + validation_split=0.2) + + # Test with dictionary inputs + model.fit( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + epochs=1, + batch_size=5, + verbose=0) + model.fit( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + epochs=1, + batch_size=5, + verbose=1) + model.fit( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + validation_data=({'input_a': input_a_np, + 'input_b': input_b_np + }, + { + 'dense': output_d_np, + 'dropout': output_e_np + }), + epochs=1, + batch_size=5, + verbose=0) + model.train_on_batch({ + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}) + # Test with lists for loss, metrics + loss = ['mae', 'mse'] + metrics = ['acc', 'mae'] + model.compile(optimizer, loss, metrics=metrics) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=0) + + # Test with dictionaries for loss, metrics, loss weights + loss = {'dense': 'mse', 'dropout': 'mae'} + loss_weights = {'dense': 1., 'dropout': 0.5} + metrics = {'dense': 'mse', 'dropout': 'mae'} + model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=0) + + # Invalid use cases + with self.assertRaises(AttributeError): + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + validation_data=([input_a_np, input_b_np], 0, 0), + verbose=0) + with self.assertRaises(ValueError): + model.train_on_batch({'input_a': input_a_np}, + [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + model.train_on_batch([input_a_np], [output_d_np, output_e_np]) + with self.assertRaises(AttributeError): + model.train_on_batch(1, [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + model.train_on_batch(input_a_np, [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + bad_input = np.random.random((11, 3)) + model.train_on_batch([bad_input, input_b_np], + [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + bad_target = np.random.random((11, 4)) + model.train_on_batch([input_a_np, input_b_np], + [bad_target, output_e_np]) + + # Build single-input model + x = keras.layers.Input(shape=(3,), name='input_a') + y = keras.layers.Dense(4)(x) + model = keras.models.Model(x, y) + model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') + # This will work + model.fit([input_a_np], output_d_np, epochs=1) + with self.assertRaises(ValueError): + model.fit([input_a_np, input_a_np], output_d_np, epochs=1) + + def test_evaluate_predict_on_arrays(self): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + loss_weights = [1., 0.5] + metrics = ['mae'] + model.compile( + optimizer, + loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=None) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + # Test evaluate at different verbosity + out = model.evaluate( + [input_a_np, input_b_np], [output_d_np, output_e_np], + batch_size=5, + verbose=0) + self.assertEqual(len(out), 5) + out = model.evaluate( + [input_a_np, input_b_np], [output_d_np, output_e_np], + batch_size=5, + verbose=1) + self.assertEqual(len(out), 5) + out = model.evaluate( + [input_a_np, input_b_np], [output_d_np, output_e_np], + batch_size=5, + verbose=2) + self.assertEqual(len(out), 5) + out = model.test_on_batch([input_a_np, input_b_np], + [output_d_np, output_e_np]) + self.assertEqual(len(out), 5) + + # Test evaluate with dictionary inputs + model.evaluate( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + batch_size=5, + verbose=0) + model.evaluate( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + batch_size=5, + verbose=1) + + # Test predict + out = model.predict([input_a_np, input_b_np], batch_size=5) + self.assertEqual(len(out), 2) + out = model.predict({'input_a': input_a_np, 'input_b': input_b_np}) + self.assertEqual(len(out), 2) + out = model.predict_on_batch({ + 'input_a': input_a_np, + 'input_b': input_b_np + }) + self.assertEqual(len(out), 2) + + def test_invalid_loss_or_metrics(self): + num_classes = 5 + train_samples = 1000 + test_samples = 1000 + input_dim = 5 + + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_shape=(input_dim,))) + model.add(keras.layers.Activation('relu')) + model.add(keras.layers.Dense(num_classes)) + model.add(keras.layers.Activation('softmax')) + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + np.random.seed(1337) + + (x_train, y_train), (_, _) = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + + with self.assertRaises(ValueError): + model.fit(x_train, np.concatenate([y_train, y_train], axis=-1)) + + with self.assertRaises(TypeError): + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=set(0)) + + with self.assertRaises(ValueError): + model.compile(loss=None, + optimizer='rms') + + +class LossWeightingTest(test.TestCase): + + def test_class_weights(self): + num_classes = 5 + batch_size = 5 + epochs = 5 + weighted_class = 3 + train_samples = 3000 + test_samples = 3000 + input_dim = 5 + + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_shape=(input_dim,))) + model.add(keras.layers.Activation('relu')) + model.add(keras.layers.Dense(num_classes)) + model.add(keras.layers.Activation('softmax')) + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + np.random.seed(1337) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + int_y_test = y_test.copy() + int_y_train = y_train.copy() + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + test_ids = np.where(int_y_test == np.array(weighted_class))[0] + + class_weight = dict([(i, 1.) for i in range(num_classes)]) + class_weight[weighted_class] = 2. + + sample_weight = np.ones((y_train.shape[0])) + sample_weight[int_y_train == weighted_class] = 2. + + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 3, + verbose=0, + class_weight=class_weight, + validation_data=(x_train, y_train, sample_weight)) + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 2, + verbose=0, + class_weight=class_weight) + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 2, + verbose=0, + class_weight=class_weight, + validation_split=0.1) + + model.train_on_batch( + x_train[:batch_size], y_train[:batch_size], class_weight=class_weight) + ref_score = model.evaluate(x_test, y_test, verbose=0) + score = model.evaluate( + x_test[test_ids, :], y_test[test_ids, :], verbose=0) + self.assertLess(score, ref_score) + + def test_sample_weights(self): + num_classes = 5 + batch_size = 5 + epochs = 5 + weighted_class = 3 + train_samples = 3000 + test_samples = 3000 + input_dim = 5 + + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_shape=(input_dim,))) + model.add(keras.layers.Activation('relu')) + model.add(keras.layers.Dense(num_classes)) + model.add(keras.layers.Activation('softmax')) + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + np.random.seed(43) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + int_y_test = y_test.copy() + int_y_train = y_train.copy() + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + test_ids = np.where(int_y_test == np.array(weighted_class))[0] + + class_weight = dict([(i, 1.) for i in range(num_classes)]) + class_weight[weighted_class] = 2. + + sample_weight = np.ones((y_train.shape[0])) + sample_weight[int_y_train == weighted_class] = 2. + + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 3, + verbose=0, + sample_weight=sample_weight) + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 3, + verbose=0, + sample_weight=sample_weight, + validation_split=0.1) + model.train_on_batch( + x_train[:batch_size], + y_train[:batch_size], + sample_weight=sample_weight[:batch_size]) + model.test_on_batch( + x_train[:batch_size], + y_train[:batch_size], + sample_weight=sample_weight[:batch_size]) + + def test_temporal_sample_weights(self): + num_classes = 5 + weighted_class = 3 + train_samples = 1000 + test_samples = 1000 + input_dim = 5 + timesteps = 3 + + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(num_classes), + input_shape=(timesteps, input_dim))) + model.add(keras.layers.Activation('softmax')) + + np.random.seed(1337) + (_, y_train), _ = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + int_y_train = y_train.copy() + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + + class_weight = dict([(i, 1.) for i in range(num_classes)]) + class_weight[weighted_class] = 2. + + sample_weight = np.ones((y_train.shape[0])) + sample_weight[int_y_train == weighted_class] = 2. + with self.assertRaises(ValueError): + model.compile( + loss='binary_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001), + sample_weight_mode='temporal') + + def test_class_weight_invalid_use_case(self): + num_classes = 5 + train_samples = 1000 + test_samples = 1000 + input_dim = 5 + timesteps = 3 + + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(num_classes), + input_shape=(timesteps, input_dim))) + model.add(keras.layers.Activation('softmax')) + model.compile( + loss='binary_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + class_weight = dict([(i, 1.) for i in range(num_classes)]) + + del class_weight[1] + with self.assertRaises(ValueError): + model.fit(x_train, y_train, + epochs=0, verbose=0, class_weight=class_weight) + + with self.assertRaises(ValueError): + model.compile( + loss='binary_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001), + sample_weight_mode=[]) + + # Build multi-output model + x = keras.Input((3,)) + y1 = keras.layers.Dense(4, name='1')(x) + y2 = keras.layers.Dense(4, name='2')(x) + model = keras.models.Model(x, [y1, y2]) + model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') + x_np = np.random.random((10, 3)) + y_np = np.random.random((10, 4)) + w_np = np.random.random((10,)) + # This will work + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': w_np}) + # These will not + with self.assertRaises(ValueError): + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=[w_np]) + with self.assertRaises(TypeError): + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=w_np) + with self.assertRaises(ValueError): + bad_w_np = np.random.random((11,)) + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) + with self.assertRaises(ValueError): + bad_w_np = np.random.random((10, 2)) + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) + with self.assertRaises(ValueError): + bad_w_np = np.random.random((10, 2, 2)) + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) + + +class TestDynamicTrainability(test.TestCase): + + def test_trainable_warning(self): + x = np.random.random((5, 3)) + y = np.random.random((5, 2)) + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=3)) + model.trainable = False + model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') + model.trainable = True + with self.assertRaises(ValueError): + model.train_on_batch(x, y) + + def test_trainable_argument(self): + x = np.random.random((5, 3)) + y = np.random.random((5, 2)) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=3, trainable=False)) + model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') + out = model.predict(x) + with self.assertRaises(ValueError): + model.train_on_batch(x, y) + out_2 = model.predict(x) + self.assertAllClose(out, out_2) + + # test with nesting + inputs = keras.layers.Input(shape=(3,)) + output = model(inputs) + model = keras.models.Model(inputs, output) + model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') + out = model.predict(x) + with self.assertRaises(ValueError): + model.train_on_batch(x, y) + out_2 = model.predict(x) + self.assertAllClose(out, out_2) + + def test_layer_trainability_switch(self): + # with constructor argument, in Sequential + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, trainable=False, input_dim=1)) + self.assertListEqual(model.trainable_weights, []) + + # by setting the `trainable` argument, in Sequential + model = keras.models.Sequential() + layer = keras.layers.Dense(2, input_dim=1) + model.add(layer) + self.assertListEqual(model.trainable_weights, layer.trainable_weights) + layer.trainable = False + self.assertListEqual(model.trainable_weights, []) + + # with constructor argument, in Model + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2, trainable=False)(x) + model = keras.models.Model(x, y) + self.assertListEqual(model.trainable_weights, []) + + # by setting the `trainable` argument, in Model + x = keras.layers.Input(shape=(1,)) + layer = keras.layers.Dense(2) + y = layer(x) + model = keras.models.Model(x, y) + self.assertListEqual(model.trainable_weights, layer.trainable_weights) + layer.trainable = False + self.assertListEqual(model.trainable_weights, []) + + def test_model_trainability_switch(self): + # a non-trainable model has no trainable weights + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2)(x) + model = keras.models.Model(x, y) + model.trainable = False + self.assertListEqual(model.trainable_weights, []) + + # same for Sequential + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=1)) + model.trainable = False + self.assertListEqual(model.trainable_weights, []) + + def test_nested_model_trainability(self): + + # a Sequential inside a Model + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(2, input_dim=1)) + + x = keras.layers.Input(shape=(1,)) + y = inner_model(x) + outer_model = keras.models.Model(x, y) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + # a Sequential inside a Sequential + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(2, input_dim=1)) + outer_model = keras.models.Sequential() + outer_model.add(inner_model) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + # a Model inside a Model + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2)(x) + inner_model = keras.models.Model(x, y) + x = keras.layers.Input(shape=(1,)) + y = inner_model(x) + outer_model = keras.models.Model(x, y) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + # a Model inside a Sequential + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2)(x) + inner_model = keras.models.Model(x, y) + outer_model = keras.models.Sequential() + outer_model.add(inner_model) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + +class TestTrainingUtils(test.TestCase): + + def test_check_array_lengths(self): + keras.engine.training._check_array_lengths(None, None, None) + a_np = np.random.random((4, 3, 3)) + keras.engine.training._check_array_lengths(a_np, a_np, a_np) + keras.engine.training._check_array_lengths( + [a_np, a_np], [a_np, a_np], [a_np, a_np]) + keras.engine.training._check_array_lengths([None], [None], [None]) + + b_np = np.random.random((3, 4)) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths(a_np, None, None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths(a_np, a_np, None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths([a_np], [None], None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths([a_np], [b_np], None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths([a_np], None, [b_np]) + + def test_slice_arrays(self): + input_a = np.random.random((10, 3)) + keras.engine.training._slice_arrays(None) + keras.engine.training._slice_arrays(input_a, 0) + keras.engine.training._slice_arrays(input_a, 0, 1) + keras.engine.training._slice_arrays(input_a, stop=2) + input_a = [None, [1, 1], None, [1, 1]] + keras.engine.training._slice_arrays(input_a, 0) + keras.engine.training._slice_arrays(input_a, 0, 1) + keras.engine.training._slice_arrays(input_a, stop=2) + input_a = [None] + keras.engine.training._slice_arrays(input_a, 0) + keras.engine.training._slice_arrays(input_a, 0, 1) + keras.engine.training._slice_arrays(input_a, stop=2) + input_a = None + keras.engine.training._slice_arrays(input_a, 0) + keras.engine.training._slice_arrays(input_a, 0, 1) + keras.engine.training._slice_arrays(input_a, stop=2) + + def test_fit_with_BatchNorm(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_dim=4)) + model.add(keras.layers.BatchNormalization()) + model.add(keras.layers.Activation('tanh')) + model.add(keras.layers.Dropout(0.2)) + + input_a_np = np.random.random((10, 4)) + output_b_np = np.random.random((10, 10)) + + model.compile(loss='binary_crossentropy', optimizer=RMSPropOptimizer(0.001)) + model.fit(input_a_np, output_b_np, epochs=1, batch_size=5, verbose=0) + + def test_fit_with_regularization(self): + model = keras.models.Sequential() + with self.assertRaises(ValueError): + model.add( + keras.layers.Dense(4, input_dim=3, + kernel_regularizer=keras.regularizers.l2(0.01), + activity_regularizer=keras.regularizers.l1(0.01))) + + +if __name__ == '__main__': + # Bazel sets these environment variables to very long paths. + # Tempfile uses them to create long paths, and in turn multiprocessing + # library tries to create sockets named after paths. Delete whatever bazel + # writes to these to avoid tests failing due to socket addresses being too + # long. + for var in ('TMPDIR', 'TMP', 'TEMP'): + if var in os.environ: + del os.environ[var] + + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index 5a033a04ade6c9b93ab32fb45f31d3efec85cd3f..b380238e4e2bb3bccbfc5efdc0db213d86910fe5 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -78,6 +78,14 @@ class TrainingTest(test.TestCase): verbose=2) model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + # Test model with input data as a list of lists + model.fit( + [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)], + [output_d_np, output_e_np], + epochs=2, + batch_size=5, + verbose=2) + # Test with validation data model.fit( [input_a_np, input_b_np], [output_d_np, output_e_np], @@ -205,6 +213,16 @@ class TrainingTest(test.TestCase): with self.assertRaises(ValueError): model.fit([input_a_np, input_a_np], output_d_np, epochs=1) + # Test model on a list of floats + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 4)) + + model.fit([np.ndarray.tolist(input_a_np)], + [np.ndarray.tolist(input_b_np)], + epochs=2, + batch_size=5, + verbose=2) + def test_evaluate_predict_on_arrays(self): with self.test_session(): a = keras.layers.Input(shape=(3,), name='input_a') diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py index 6ee3fb48b2f1426b87c5c1947e90d0797e9b9ff7..ea2d3f2f04a591ab97f09dd0a43829fe9f75fc9e 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core.py +++ b/tensorflow/python/keras/_impl/keras/layers/core.py @@ -23,6 +23,7 @@ import types as python_types import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import activations from tensorflow.python.keras._impl.keras import backend as K @@ -119,7 +120,8 @@ class Dropout(tf_core_layers.Dropout, Layer): if training is None: training = K.learning_phase() output = super(Dropout, self).call(inputs, training=training) - if training is K.learning_phase(): + # EagerTensor object has no attribute _uses_learning_phase + if not context.in_eager_mode() and training is K.learning_phase(): output._uses_learning_phase = True # pylint: disable=protected-access return output diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization.py b/tensorflow/python/keras/_impl/keras/layers/normalization.py index 965ef70e6e6cb488aa4832462da4a2cb43e964a6..eecb14ceaa38968d54ea6702e534ee29b6e180d5 100644 --- a/tensorflow/python/keras/_impl/keras/layers/normalization.py +++ b/tensorflow/python/keras/_impl/keras/layers/normalization.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers @@ -108,7 +109,7 @@ class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer): if training is None: training = K.learning_phase() output = super(BatchNormalization, self).call(inputs, training=training) - if training is K.learning_phase(): + if context.in_graph_mode() and training is K.learning_phase(): output._uses_learning_phase = True # pylint: disable=protected-access return output diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py index e47987aadc48e1f722558e32929ab81ad82bea0f..a55a5e39a69c4286f0a002474b5ad543c04bf256 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers.py +++ b/tensorflow/python/keras/_impl/keras/optimizers.py @@ -24,6 +24,7 @@ import copy import six from six.moves import zip # pylint: disable=redefined-builtin +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops from tensorflow.python.keras._impl.keras import backend as K @@ -680,7 +681,14 @@ class TFOptimizer(Optimizer): def __init__(self, optimizer): # pylint: disable=super-init-not-called self.optimizer = optimizer with K.name_scope(self.__class__.__name__): - self.iterations = K.variable(0, dtype='int64', name='iterations') + if context.in_graph_mode(): + self.iterations = K.variable(0, dtype='int64', name='iterations') + + def apply_gradients(self, grads): + self.optimizer.apply_gradients(grads) + + def get_grads(self, loss, params): + return self.optimizer.compute_gradients(loss, params) def get_updates(self, loss, params): grads = self.optimizer.compute_gradients(loss, params) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3a6058054be4fad05daf20703384d57624d95deb..d4ceb2e489c8a20d26eaf9d89b12992d2b8673d7 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1294,7 +1294,7 @@ cuda_py_test( cuda_py_test( name = "control_flow_ops_py_test", - # TOOD(b/70473603): change this back to "small" once the C API is + # TODO(b/70473603): change this back to "small" once the C API is # permanently enabled size = "medium", srcs = ["control_flow_ops_py_test.py"], diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 68b7c3a98a160017c2d5d84d3d9a8f92a6ab66b0..ee7a5621e0c860cf710dd619ba85465c717a5196 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -414,7 +414,7 @@ class MeshgridTest(test_util.TensorFlowTestCase): def _compareDiffType(self, n, np_dtype, use_gpu): inputs = [] for index in ("ij", "xy"): - for i in range(n): + for _ in range(n): x = np.linspace(-10, 10, 5).astype(np_dtype) if np_dtype in (np.complex64, np.complex128): x += 1j @@ -422,8 +422,8 @@ class MeshgridTest(test_util.TensorFlowTestCase): numpy_out = np.meshgrid(*inputs, indexing=index) with self.test_session(use_gpu=use_gpu): tf_out = array_ops.meshgrid(*inputs, indexing=index) - for X, _X in zip(numpy_out, tf_out): - self.assertAllEqual(X, _X.eval()) + for x_np, x_tf in zip(numpy_out, tf_out): + self.assertAllEqual(x_np, x_tf.eval()) def testCompare(self): for t in (np.float16, np.float32, np.float64, np.int32, np.int64, @@ -952,6 +952,32 @@ class SliceAssignTest(test_util.TensorFlowTestCase): v = variables.Variable([1, 2]) sess.run(v[:].assign([1, 2])) + def testTypeError(self): + init_val = constant_op.constant([1, 2], dtype=dtypes.int32) + too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8) + too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64) + v = variables.Variable(init_val) + with self.assertRaises(TypeError): + v[:].assign(too_small_val) + with self.assertRaises(TypeError): + v[:].assign(too_large_val) + + def testTypeErrorResource(self): + init_val = constant_op.constant([1, 2], dtype=dtypes.int32) + too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8) + too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64) + v = resource_variable_ops.ResourceVariable(init_val) + with self.test_session() as sess: + sess.run(v.initializer) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "l-value dtype int32 does not match r-value dtype int64"): + sess.run(v[:].assign(too_large_val)) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "l-value dtype int32 does not match r-value dtype int8"): + sess.run(v[:].assign(too_small_val)) + class ShapeSizeRankTest(test_util.TensorFlowTestCase): @@ -989,7 +1015,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, "maxlen must be scalar"): array_ops.sequence_mask([10, 20], [10, 20]) - def testOneDimensional(self): + def testOneDimensionalWithMaxlen(self): with self.test_session(): res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5) self.assertAllEqual(res.get_shape(), [3, 5]) @@ -998,9 +1024,11 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): [[True, False, False, False, False], [True, True, True, False, False], [True, True, False, False, False]]) + def testOneDimensionalDtypeWithoutMaxlen(self): + with self.test_session(): # test dtype and default maxlen: - res = array_ops.sequence_mask( - constant_op.constant([0, 1, 4]), dtype=dtypes.float32) + res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]), + dtype=dtypes.float32) if ops._USE_C_API: self.assertAllEqual(res.get_shape().as_list(), [3, 4]) else: @@ -1009,6 +1037,20 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): res.eval(), [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]) + def testOneDimensionalWithoutMaxlen(self): + with self.test_session(): + res = array_ops.sequence_mask( + constant_op.constant([0, 1, 4])) + if ops._USE_C_API: + self.assertAllEqual(res.get_shape().as_list(), [3, 4]) + else: + self.assertAllEqual(res.get_shape().as_list(), [3, None]) + self.assertAllEqual( + res.eval(), + [[False, False, False, False], + [True, False, False, False], + [True, True, True, True]]) + def testTwoDimensional(self): with self.test_session(): res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5) @@ -1029,6 +1071,11 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]]]) + def testUnknownShape(self): + lengths = array_ops.placeholder(dtype=dtypes.int32) + res = array_ops.sequence_mask(lengths) + self.assertEqual(res.shape, None) + def testDtypes(self): def check_dtypes(lengths_dtype, maxlen_dtype): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 5d648bb235c8a8a0ae435e0c249bcb64ba787b08..4fafc36014e65318c72610949bfdfb092293c95a 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import gen_state_ops @@ -143,7 +144,7 @@ class ControlFlowTest(test.TestCase): enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True) nine = constant_op.constant(9) - enter_nine = control_flow_ops.enter(nine, "foo_1") + enter_nine = gen_control_flow_ops._enter(nine, "foo_1") op = state_ops.assign(enter_v, enter_nine) v2 = control_flow_ops.with_dependencies([op], enter_v) v3 = control_flow_ops.exit(v2) @@ -163,9 +164,9 @@ class ControlFlowTest(test.TestCase): def testEnterMulExit(self): with self.test_session(): data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - enter_data = control_flow_ops.enter(data, "foo_1", False) + enter_data = gen_control_flow_ops._enter(data, "foo_1", False) five = constant_op.constant(5) - enter_five = control_flow_ops.enter(five, "foo_1", False) + enter_five = gen_control_flow_ops._enter(five, "foo_1", False) mul_op = math_ops.multiply(enter_data, enter_five) exit_op = control_flow_ops.exit(mul_op) @@ -177,11 +178,12 @@ class ControlFlowTest(test.TestCase): v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) # If is_constant=True, the shape information should be propagated. - enter_v_constant = control_flow_ops.enter(v, "frame1", is_constant=True) + enter_v_constant = gen_control_flow_ops._enter( + v, "frame1", is_constant=True) self.assertEqual(enter_v_constant.shape, [2]) # Otherwise, the shape should be unknown. - enter_v_non_constant = control_flow_ops.enter( + enter_v_non_constant = gen_control_flow_ops._enter( v, "frame2", is_constant=False) self.assertEqual(enter_v_non_constant.shape, None) @@ -255,8 +257,8 @@ class ControlFlowTest(test.TestCase): false = ops.convert_to_tensor(False) n = constant_op.constant(10) - enter_false = control_flow_ops.enter(false, "foo_1", False) - enter_n = control_flow_ops.enter(n, "foo_1", False) + enter_false = gen_control_flow_ops._enter(false, "foo_1", False) + enter_n = gen_control_flow_ops._enter(n, "foo_1", False) merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0] switch_n = control_flow_ops.switch(merge_n, enter_false) @@ -273,9 +275,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = control_flow_ops.enter(zero, "foo", False) - enter_one = control_flow_ops.enter(one, "foo", True) - enter_n = control_flow_ops.enter(n, "foo", True) + enter_i = gen_control_flow_ops._enter(zero, "foo", False) + enter_one = gen_control_flow_ops._enter(one, "foo", True) + enter_n = gen_control_flow_ops._enter(n, "foo", True) with ops.device(test.gpu_device_name()): merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -299,9 +301,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = control_flow_ops.enter(zero, "foo", False) - enter_one = control_flow_ops.enter(one, "foo", True) - enter_n = control_flow_ops.enter(n, "foo", True) + enter_i = gen_control_flow_ops._enter(zero, "foo", False) + enter_one = gen_control_flow_ops._enter(one, "foo", True) + enter_n = gen_control_flow_ops._enter(n, "foo", True) merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -322,8 +324,8 @@ class ControlFlowTest(test.TestCase): def testDifferentFrame(self): with self.test_session(): data = array_ops.placeholder(dtypes.float32, shape=[]) - enter_1 = control_flow_ops.enter(data, "foo_1", False) - enter_2 = control_flow_ops.enter(data, "foo_2", False) + enter_1 = gen_control_flow_ops._enter(data, "foo_1", False) + enter_2 = gen_control_flow_ops._enter(data, "foo_2", False) res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError("has inputs from different frames"): res.eval(feed_dict={data: 1.0}) diff --git a/tensorflow/python/kernel_tests/control_flow_util_test.py b/tensorflow/python/kernel_tests/control_flow_util_test.py index 39e96f74b0461da0cf499e303b30a4a41aae4899..23185eaeece0d56fd83ecdf9e02c778712420465 100644 --- a/tensorflow/python/kernel_tests/control_flow_util_test.py +++ b/tensorflow/python/kernel_tests/control_flow_util_test.py @@ -41,17 +41,17 @@ class ControlFlowUtilTest(test.TestCase): self.assertFalse(control_flow_util.IsSwitch(test_ops.int_output().op)) def testIsLoopEnter(self): - enter = gen_control_flow_ops.enter(1, frame_name="name").op + enter = gen_control_flow_ops._enter(1, frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(enter)) - ref_enter = gen_control_flow_ops.ref_enter(test_ops.ref_output(), - frame_name="name").op + ref_enter = gen_control_flow_ops._ref_enter(test_ops.ref_output(), + frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(ref_enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(ref_enter)) - const_enter = gen_control_flow_ops.enter(1, frame_name="name", - is_constant=True).op + const_enter = gen_control_flow_ops._enter(1, frame_name="name", + is_constant=True).op self.assertTrue(control_flow_util.IsLoopEnter(const_enter)) self.assertTrue(control_flow_util.IsLoopConstantEnter(const_enter)) diff --git a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py index 5c7624f1f6be4da91ca74d4ef2ed81a21890b35c..6ea9f1badc3b8fac06fe6328f95714b93de97c0e 100644 --- a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py +++ b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py @@ -84,7 +84,7 @@ class ExtractImagePatches(test.TestCase): patches=patches) def testKsize2x2Stride1x1Rate1x1Valid(self): - """Test for 1x1 kernel .""" + """Test for 2x2 kernel with VALID padding.""" # [1, 2, 2, 1] image = [[[[1], [2]], [[3], [4]]]] # [1, 1, 1, 4] @@ -98,7 +98,7 @@ class ExtractImagePatches(test.TestCase): patches=patches) def testKsize2x2Stride1x1Rate1x1Same(self): - """Test for 1x1 kernel .""" + """Test for 2x2 kernel with SAME padding.""" # [1, 2, 2, 1] image = [[[[1], [2]], [[3], [4]]]] # [1, 2, 2, 4] @@ -111,6 +111,20 @@ class ExtractImagePatches(test.TestCase): padding="SAME", patches=patches) + def testKsize2x2Stride1x1Rate2x2Valid(self): + """Test for 2x2 kernel with 2x2 dilation.""" + # [1, 2, 2, 1] + image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32) + # [1, 2, 2, 4] + patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]], + [[4, 6, 12, 14], [5, 7, 13, 15]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[2, 2], + padding="VALID", + patches=patches) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py index f91875c6f0c1a7bfa388ec1b1a58f06b65889c3e..61944f7e3197844d00cbc001459e48b50c9003b4 100644 --- a/tensorflow/python/kernel_tests/io_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops_test.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 4e18eaa4e8281c799e4669b2d6083c00bc1e2863..fd1b5bab6f5aa072c8821eb053bd8d39391be4d4 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -39,6 +39,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + shard_count = 5, tags = ["noasan"], # times out b/63678675 ) @@ -57,6 +58,7 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], + shard_count = 5, ) cuda_py_test( @@ -73,6 +75,7 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], + shard_count = 5, ) cuda_py_test( @@ -88,6 +91,7 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + shard_count = 5, ) cuda_py_test( @@ -134,6 +138,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + shard_count = 5, ) filegroup( diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 81af3a0887d09a7736a145a5b3c99c9391691724..f1fbe1a745bcc851154121e6e2123b92bba6fec1 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -953,14 +953,14 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): # Compute the expected loss 'manually'. total = np.zeros((batch_size,)) for b in range(batch_size): - for i in range(dims): - for j in range(dims): + for i in range(dims-1): + for j in range(i+1, dims): x = self._predictions[b, i].item() - self._predictions[b, j].item() y = self._labels[b, i].item() - self._labels[b, j].item() diff = (x - y) total[b] += (diff * diff) - self._expected_losses = np.divide(total, 9.0) + self._expected_losses = np.divide(total, 3.0) def testValueErrorThrownWhenWeightIsNone(self): with self.test_session(): @@ -1060,7 +1060,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): [[8, 1, 3], [7, 8, 9], [10, 11, 12]], ]) self._test_valid_weights( - labels, predictions, expected_loss=122.22222) + labels, predictions, expected_loss=137.5) def test3dWeightedScalar(self): labels = np.array([ @@ -1073,7 +1073,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) weight = 3.0 self._test_valid_weights( - labels, predictions, expected_loss=weight * 122.22222, + labels, predictions, expected_loss=weight * 137.5, weights=weight) def _test_invalid_weights( @@ -1124,7 +1124,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) self._test_valid_weights( # TODO(ptucker): This doesn't look right. - labels, predictions, expected_loss=9 * 122.22222, + labels, predictions, expected_loss=9 * 137.5, weights=np.ones((2, 3, 3))) def testLossWithAllZeroBatchSpecificWeights(self): @@ -1345,6 +1345,34 @@ class ComputeWeightedLossTest(test.TestCase): self.assertAllClose( np.mean(self._raw_losses), unweighted_loss.eval()) + def testUnweightedFromPlaceholder(self): + for reduction in losses.Reduction.all(): + with ops.Graph().as_default() as g: + self.assertEqual(0, len(util.get_losses())) + raw_losses = array_ops.placeholder(dtype=dtypes.float32) + feed_dict = {raw_losses: self._raw_losses} + unweighted_losses = ( + losses.compute_weighted_loss(raw_losses, reduction=reduction), + losses.compute_weighted_loss( + raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction), + losses.compute_weighted_loss( + raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction), + ) + self.assertEqual(3, len(util.get_losses())) + with self.test_session(g): + for unweighted_loss in unweighted_losses: + if reduction == losses.Reduction.NONE: + self.assertAllClose( + self._raw_losses, unweighted_loss.eval(feed_dict)) + elif reduction == losses.Reduction.SUM: + self.assertAllClose( + np.sum(self._raw_losses), unweighted_loss.eval(feed_dict)) + else: + # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS, + # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE. + self.assertAllClose( + np.mean(self._raw_losses), unweighted_loss.eval(feed_dict)) + def testScalarWeight(self): with ops.Graph().as_default(): self.assertEqual(0, len(util.get_losses())) diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py index 317b8dc05beac7642c384bf89e6d154be50f6992..68d626de2c5cdd91ee332247c05ddce2a558a35e 100644 --- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -54,9 +55,13 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_): band_np = np.tril(band_np, upper) if batch_shape_ is not (): band_np = np.tile(band_np, batch_shape_ + (1, 1)) - with self.test_session(use_gpu=False): - band = array_ops.matrix_band_part(batch_mat, lower, upper) - self.assertAllEqual(band_np, band.eval()) + for index_dtype in [dtypes_lib.int32, dtypes_lib.int64]: + with self.test_session(use_gpu=False): + band = array_ops.matrix_band_part( + batch_mat, + constant_op.constant(lower, index_dtype), + constant_op.constant(upper, index_dtype)) + self.assertAllEqual(band_np, band.eval()) return Test diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 92fb68820e04c3db1385296d91d956134b8ff2d4..c7181497d891f6d35a788c90bf59a0ce5a536328 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -396,66 +396,66 @@ class PyFuncTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testEagerSingleOutputFloat32(self): - a = array_ops.ones((3, 3), dtype=dtypes.float32) - x = array_ops.ones((3, 1), dtype=dtypes.float32) - output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) - with self.test_session(): + with test_util.device(use_gpu=True): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) ret = self.evaluate(output) self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) @test_util.run_in_graph_and_eager_modes() def testEagerArrayOutput(self): - a = array_ops.ones((3, 3), dtype=dtypes.int32) - x = array_ops.ones((3, 1), dtype=dtypes.int32) - output = script_ops.eager_py_func( - lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32]) - - with self.test_session(): + with test_util.device(use_gpu=True): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + output = script_ops.eager_py_func( + lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32]) ret = self.evaluate(output) - self.assertAllEqual(ret, [[[3], [3], [3]]]) + self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]]) @test_util.run_in_graph_and_eager_modes() def testEagerReturnNone(self): + with test_util.device(use_gpu=True): + def no_return_value(): + return - def no_return_value(): - return - - output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) - ret = self.evaluate(output) - if context.in_eager_mode(): - self.assertEquals(len(ret), 0) - else: - self.assertIsNone(ret) + output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) + ret = self.evaluate(output) + if context.in_eager_mode(): + self.assertEquals(len(ret), 0) + else: + self.assertIsNone(ret) @test_util.run_in_graph_and_eager_modes() def testEagerPyFuncInDefun(self): + with test_util.device(use_gpu=True): + def wrapper(): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) - def wrapper(): - a = array_ops.ones((3, 3), dtype=dtypes.int32) - x = array_ops.ones((3, 1), dtype=dtypes.int32) - return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32) - - wrapped = function.defun(wrapper) - ret = self.evaluate(wrapped()) - self.assertAllEqual(ret, [[3], [3], [3]]) + wrapped = function.defun(wrapper) + ret = self.evaluate(wrapped()) + self.assertAllEqual(ret, [[3.0], [3.0], [3.0]]) @test_util.run_in_graph_and_eager_modes() def testEagerExceptionHandling(self): - self._testExceptionHandling( - ValueError, errors.InvalidArgumentError, eager=True) - self._testExceptionHandling( - TypeError, errors.InvalidArgumentError, eager=True) - self._testExceptionHandling( - StopIteration, errors.OutOfRangeError, eager=True) - self._testExceptionHandling( - MemoryError, errors.ResourceExhaustedError, eager=True) - self._testExceptionHandling( - NotImplementedError, errors.UnimplementedError, eager=True) - - class WeirdError(Exception): - pass - - self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) + with test_util.device(use_gpu=True): + self._testExceptionHandling( + ValueError, errors.InvalidArgumentError, eager=True) + self._testExceptionHandling( + TypeError, errors.InvalidArgumentError, eager=True) + self._testExceptionHandling( + StopIteration, errors.OutOfRangeError, eager=True) + self._testExceptionHandling( + MemoryError, errors.ResourceExhaustedError, eager=True) + self._testExceptionHandling( + NotImplementedError, errors.UnimplementedError, eager=True) + + class WeirdError(Exception): + pass + + self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index b4b555591d054226210eb6af20036967b240928f..cd94579688130974f640d1dea2afadeadadd551b 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.util import compat @test_util.with_c_api @@ -170,6 +171,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) + def testScatterUpdateString(self): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.string, shape=[1, 1]) + self.evaluate(resource_variable_ops.assign_variable_op( + handle, constant_op.constant([["a"]], dtype=dtypes.string))) + self.evaluate(resource_variable_ops.resource_scatter_update( + handle, [0], constant_op.constant([["b"]], dtype=dtypes.string))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) + self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]), + compat.as_bytes("b")) + # TODO(alive): get this to work in Eager mode. def testGPU(self): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 5a54f448d092093db668570d055801f9f9cd0f9f..bbce6b7d47325b8209815230426672ec6894147f 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -46,7 +46,8 @@ class SegmentReductionHelper(test.TestCase): return constant_op.constant( np_values, shape=input_shape, dtype=dtype), np_values - def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None): + def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None, + initial_value=0): if not x.size: return np.array([]) indices = np.asarray(indices) @@ -64,13 +65,8 @@ class SegmentReductionHelper(test.TestCase): else: output[index] = x_flat[i] # zero initialize values that are still uncalcuated. - # output = [o if o is not None else np.zeros(slice_shape) for o in output] - if not op1 == np.max: - output = [o if o is not None else np.zeros(slice_shape) for o in output] - else: - zeroslice = np.zeros(slice_shape) - zeroslice.fill(dtype.min) - output = [o if o is not None else zeroslice for o in output] + initial_value_slice = np.ones(slice_shape) * initial_value + output = [o if o is not None else initial_value_slice for o in output] if op2 is not None: output = [op2(o) for o in output] output = [o.reshape(slice_shape) for o in output] @@ -82,6 +78,9 @@ class SegmentReductionHelper(test.TestCase): def _mean_reduce_op(self, x): return x[0] / x[1] if isinstance(x, tuple) else x + def _sqrt_n_reduce_op(self, x): + return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x + class SegmentReductionOpTest(SegmentReductionHelper): @@ -244,27 +243,61 @@ class SegmentReductionOpTest(SegmentReductionHelper): self.assertAllClose(jacob_t, jacob_n) -class UnsortedSegmentSumTest(SegmentReductionHelper): +class UnsortedSegmentTest(SegmentReductionHelper): + + def __init__(self, methodName='runTest'): + # Each item is np_op1, np_op2, tf_op, initial_value functor + self.ops_list = [(np.add, None, + math_ops.unsorted_segment_sum, lambda t: 0), + (self._mean_cum_op, self._mean_reduce_op, + math_ops.unsorted_segment_mean, lambda t: 0), + (self._mean_cum_op, self._sqrt_n_reduce_op, + math_ops.unsorted_segment_sqrt_n, lambda t: 0), + (np.ndarray.__mul__, None, + math_ops.unsorted_segment_prod, lambda t: 1), + (np.minimum, None, + math_ops.unsorted_segment_min, lambda t: t.max), + (np.maximum, None, + math_ops.unsorted_segment_max, lambda t: t.min)] + + # A subset of ops has been enabled for complex numbers + self.complex_ops_list = [(np.add, None, + math_ops.unsorted_segment_sum, lambda t: 0)] + self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32, + dtypes_lib.float64] + self.all_dtypes = (self.differentiable_dtypes + + [dtypes_lib.bfloat16, + dtypes_lib.int64, dtypes_lib.int32, + dtypes_lib.complex64, dtypes_lib.complex128]) + super(UnsortedSegmentTest, self).__init__(methodName=methodName) def testValues(self): - dtypes = [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64, - dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128 - ] indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) num_segments = 12 for indices in indices_flat, indices_flat.reshape(5, 2): shape = indices.shape + (2,) - for dtype in dtypes: - with self.test_session(use_gpu=True): - tf_x, np_x = self._input(shape, dtype=dtype) - np_ans = self._segmentReduce( - indices, np_x, np.add, op2=None, num_segments=num_segments) - s = math_ops.unsorted_segment_sum( - data=tf_x, segment_ids=indices, num_segments=num_segments) - tf_ans = s.eval() - self.assertAllClose(np_ans, tf_ans) - self.assertShapeEqual(np_ans, s) + for dtype in self.all_dtypes: + ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list + tf_x, np_x = self._input(shape, dtype=dtype) + for use_gpu in [True, False]: + with self.test_session(use_gpu=True): + for np_op1, np_op2, tf_op, init_op in ops_list: + # sqrt_n doesn't support integers + if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer): + continue + # todo(philjd): enable this test once real_div supports bfloat16 + if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and + dtype == dtypes_lib.bfloat16): + continue + np_ans = self._segmentReduce( + indices, np_x, np_op1, np_op2, num_segments=num_segments, + initial_value=init_op(dtype)) + s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments) + tf_ans = s.eval() + if dtype is dtypes_lib.bfloat16: + tf_ans = tf_ans.astype(np.float32) + self.assertAllClose(np_ans, tf_ans) + self.assertShapeEqual(np_ans, s) def testNumSegmentsTypes(self): dtypes = [dtypes_lib.int32, dtypes_lib.int64] @@ -287,25 +320,51 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): self.assertAllClose(np_ans, tf_ans) self.assertShapeEqual(np_ans, s) - def testGradientSegmentSum(self): + def testGradients(self): num_cols = 2 - indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) + indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3]) num_segments = max(indices_flat) + 3 - for dtype in [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, - dtypes_lib.complex128]: + for dtype in self.differentiable_dtypes: + ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list for indices in indices_flat, indices_flat.reshape(5, 2): shape = indices.shape + (num_cols,) - with self.test_session(use_gpu=True): - tf_x, np_x = self._input(shape, dtype=dtype) - s = math_ops.unsorted_segment_sum( - data=tf_x, segment_ids=indices, num_segments=num_segments) + # test CPU and GPU as tf.gather behaves differently on each device + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu): + for _, _, tf_op, _ in ops_list: + tf_x, np_x = self._input(shape, dtype=dtype) + s = tf_op(tf_x, indices, num_segments) + jacob_t, jacob_n = gradient_checker.compute_gradient( + tf_x, + shape, + s, [num_segments, num_cols], + x_init_value=np_x, + delta=1) + self.assertAllClose(jacob_t, jacob_n) + + def testProdGrad(self): + # additional test for the prod gradient to ensure correct handling of zeros + values = np.array([0, 0, 1, 0, 2, 2, 3, 3, 3], dtype=np.float32) + indices = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int32) + indices_neg = np.array([-1, 0, 0, -1, 1, 1, -1, 2, 2], dtype=np.int32) + values_tf = constant_op.constant(values) + # ground truth partial derivatives + gradients_indices = np.zeros((9, 3), dtype=np.float32) + gradients_indices_neg = np.zeros((9, 3), dtype=np.float32) + # the derivative w.r.t. to the other segments is zero, so here we only + # explicitly set the grad values for the corresponding segment + gradients_indices[range(9), indices] = [0, 0, 0, 4, 0, 0, 9, 9, 9] + gradients_indices_neg[range(9), indices_neg] = [0, 1, 0, 0, 2, 2, 0, 3, 3] + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu): + for ind, grad_gt in [(indices, gradients_indices), + (indices_neg, gradients_indices_neg)]: + s = math_ops.unsorted_segment_prod(values_tf, + constant_op.constant(ind), 3) jacob_t, jacob_n = gradient_checker.compute_gradient( - tf_x, - shape, - s, [num_segments, num_cols], - x_init_value=np_x, - delta=1) - self.assertAllClose(jacob_t, jacob_n) + values_tf, (9,), s, (3,), x_init_value=values, delta=1) + self.assertAllClose(jacob_t, jacob_n) + self.assertAllClose(jacob_t, grad_gt) def testGradientMatchesSegmentSum(self): # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum @@ -318,8 +377,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): num_cols = 2 shape = [n, num_cols] num_segments = max(indices) + 1 - for dtype in [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, - dtypes_lib.complex128]: + for dtype in self.differentiable_dtypes: with self.test_session(use_gpu=True): tf_x, np_x = self._input(shape, dtype=dtype) # Results from UnsortedSegmentSum @@ -353,9 +411,8 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): unsorted.eval() def testEmptySecondDimension(self): - dtypes = [ - np.float32, np.float64, np.int64, np.int32, np.complex64, np.complex128 - ] + dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32, + np.complex64, np.complex128] with self.test_session(use_gpu=True): for dtype in dtypes: for itype in (np.int32, np.int64): @@ -364,36 +421,14 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2) self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype)) - def testGradientSegmentMax(self): - num_cols = 2 - indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) - num_segments = max(indices_flat) + 3 - for indices in indices_flat, indices_flat.reshape(5, 2): - shape = indices.shape + (num_cols,) - with self.test_session(use_gpu=True): - tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64) - s = math_ops.unsorted_segment_max( - data=tf_x, segment_ids=indices, num_segments=num_segments) - jacob_t, jacob_n = gradient_checker.compute_gradient( - tf_x, - shape, - s, - [num_segments, num_cols], - x_init_value=np_x.astype(np.double), delta=1) - self.assertAllClose(jacob_t, jacob_n) - def testDropNegatives(self): # Note: the test is done by replacing segment_ids with 8 to -1 # for index and replace values generated by numpy with 0. - dtypes = [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64, - dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128 - ] indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) num_segments = 12 for indices in indices_flat, indices_flat.reshape(5, 2): shape = indices.shape + (2,) - for dtype in dtypes: + for dtype in self.all_dtypes: with self.test_session(use_gpu=True): tf_x, np_x = self._input(shape, dtype=dtype) np_ans = self._segmentReduce( diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index 38205518b528b44313b1de83d06707b4498f061d..8ad29afd0a0f2e7fbaaf2bde956326e578466b1d 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -56,9 +56,11 @@ class TensordotTest(test_lib.TestCase): axes_ph = array_ops.placeholder(dtypes.int32) output = math_ops.tensordot(a_ph, b_ph, axes_ph) _ = sess.run( - [output], feed_dict={a_ph: a, - b_ph: b, - axes_ph: (a_axes, b_axes)}) + [output], feed_dict={ + a_ph: a, + b_ph: b, + axes_ph: (a_axes, b_axes) + }) def test_invalid_axes(self): a = [[1, 2], [3, 4]] @@ -81,28 +83,29 @@ class TensordotTest(test_lib.TestCase): with self.test_session() as sess: with self.assertRaises(errors_impl.InvalidArgumentError): _ = sess.run( - [output], feed_dict={a_ph: a, - b_ph: b, - axes_ph: axes_value}) + [output], feed_dict={ + a_ph: a, + b_ph: b, + axes_ph: axes_value + }) # Test case for 11950 def test_valid_axis(self): for axes_value in [1, 2], [[1], [2]], [[], []], 0: with self.test_session() as sess: - np_a = np.ones((3,3)) + np_a = np.ones((3, 3)) np_b = np.array([2, 3, 1])[None, None] np_ans = np.tensordot(np_a, np_b, axes_value) - tf_a = array_ops.ones((3,3), dtype=dtypes.float32) + tf_a = array_ops.ones((3, 3), dtype=dtypes.float32) tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None] tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval() self.assertAllEqual(tf_ans.shape, np_ans.shape) self.assertAllEqual(tf_ans, np_ans) - def test_partial_shape_inference(self): - for axes in ([1],[0]), 1: + for axes in ([1], [0]), 1: a = array_ops.placeholder(dtypes.float32) b = array_ops.placeholder(dtypes.float32) output = math_ops.tensordot(a, b, axes) @@ -169,9 +172,11 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): axes = array_ops.placeholder(dtypes.int32) c = math_ops.tensordot(a, b, axes) tf_ans = sess.run( - c, feed_dict={a: a_np, - b: b_np, - axes: (a_dims_np, b_dims_np)}) + c, feed_dict={ + a: a_np, + b: b_np, + axes: (a_dims_np, b_dims_np) + }) else: tf_ans = math_ops.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval() self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol) diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py index efb5b9f3641ceaebf1fd5285486b4a9bb93615cf..6ab931fdb97a8945ab610fda27a036693f0291e5 100644 --- a/tensorflow/python/kernel_tests/topk_op_test.py +++ b/tensorflow/python/kernel_tests/topk_op_test.py @@ -58,7 +58,7 @@ class TopKTest(test.TestCase): # Do some special casing of equality of indices: if indices # are not the same, but values are floating type, ensure that # the values are within epsilon of each other. - if not np.issubdtype(np_expected_values.dtype, np.float): + if not np.issubdtype(np_expected_values.dtype, np.floating): # Values are not floating point type; check indices exactly self.assertAllEqual(np_expected_indices, indices) else: diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 5d9feb07b445ca86c17a7da2bcd8c1171f68d1a3..5dea732cbaa43a40f6a1bc4beef729f3b84dad5c 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import utils as layers_util +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables @@ -139,9 +140,6 @@ class Layer(object): self._init_set_name(name) - # Holds functions for creating regularizer ops. - self._regularizer_factories = [] - # Determine variable scope. scope = kwargs.get('_scope') if scope: @@ -306,22 +304,6 @@ class Layer(object): inputs_hash = None return self._per_input_updates.get(inputs_hash, []) - def _get_regularizer_factories(self): - try: - # Some subclasses of Layer do not use its constructor. - return self._regularizer_factories - except AttributeError: - self._regularizer_factories = [] - return self._regularizer_factories - - def _maybe_create_variable_regularizers(self): - """Creates added but uninstantiated regularizers.""" - factories = self._get_regularizer_factories() - if factories: - for factory in factories: - factory() - factories[:] = [] - @property def losses(self): """Losses which are associated with this `Layer`. @@ -333,7 +315,6 @@ class Layer(object): Returns: A list of tensors. """ - self._maybe_create_variable_regularizers() if context.in_eager_mode(): # _losses may only contain variable regularization losses when executing # eagerly, and they have been saved as lambdas to be executed when @@ -417,7 +398,6 @@ class Layer(object): inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None - self._maybe_create_variable_regularizers() return self._per_input_losses.get(inputs_hash, []) def build(self, _): @@ -670,6 +650,7 @@ class Layer(object): else: scope_context_manager = vs.variable_scope( self._scope, reuse=self._reuse, auxiliary_name_scope=False) + input_shapes = None with scope_context_manager as scope: with ops.name_scope(self._name_scope_name(scope)): if not self.built: @@ -719,6 +700,9 @@ class Layer(object): else: # Deferred mode behavior: use `compute_output_shape` to # infer the number of outputs of the layer and their shapes. + if input_shapes is None: + input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) + output_shapes = self.compute_output_shape(input_shapes) output_shapes = nest.flatten(output_shapes) outputs = [ @@ -1414,7 +1398,10 @@ class _DeferredTensor(object): def __init__(self, shape, dtype, name=None): self.shape = tensor_shape.TensorShape(shape) - self.dtype = dtypes.as_dtype(dtype) + if dtype is None: + self.dtype = dtypes.as_dtype(np.float32) + else: + self.dtype = dtypes.as_dtype(dtype) self.name = name def get_shape(self): diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py index 0a5dd57621b7dc06d9bc2d69c04cd8d6936fb7c8..745843975c487c34ab854ad2bb52b93e617fcaec 100644 --- a/tensorflow/python/layers/network.py +++ b/tensorflow/python/layers/network.py @@ -621,6 +621,11 @@ class GraphNetwork(base.Layer): A list of loss tensors. """ losses = [] + if context.in_eager_mode(): + for layer in self.layers: + losses += layer.losses + return losses + # Retrieve losses for all internal layers. for layer in self.layers: if hasattr(layer, 'losses'): @@ -853,7 +858,6 @@ class GraphNetwork(base.Layer): for node in nodes: # This is always a single layer, never a list. layer = node.outbound_layer - reference_input_tensors = node.input_tensors reference_output_tensors = node.output_tensors @@ -901,12 +905,13 @@ class GraphNetwork(base.Layer): else: output_masks = [None for _ in range(len(output_tensors))] - # Apply activity regularizer if any: - if layer.activity_regularizer is not None: - regularization_losses = [ - layer.activity_regularizer(x) for x in computed_tensors - ] - layer.add_loss(regularization_losses, computed_tensors) + if context.in_graph_mode(): + if layer.activity_regularizer is not None: + regularization_losses = [ + layer.activity_regularizer(x) for x in computed_tensors + ] + # Apply activity regularizer if any: + layer.add_loss(regularization_losses, computed_tensors) if context.in_graph_mode(): # Update model updates and losses: diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index d3bfa0ee337d1f606e5e994406969685a2986ab4..e0422ef80add42307268be2743e668eb8c8acb68 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -19,6 +19,7 @@ limitations under the License. #include "numpy/arrayobject.h" #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/op_kernel.h" @@ -53,6 +54,12 @@ struct PyCall { // with this "token". string token; + // The device on which Tensors are stored; only used for EagerPyFunc. + Device* device; + + // True if and only if the op has been placed on a GPU. + bool gpu; + // True if the call is associated with an EagerPyFunc. bool eager; @@ -71,7 +78,12 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { PyObject* arg = nullptr; const Tensor& t = call->ins[i]; if (call->eager) { - arg = EagerTensorFromHandle(TFE_NewTensorHandle(t)); + if (call->gpu) { + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, call->device)); + } else { + // TFE_TensorHandle assumes that CPU is identified by `nullptr`. + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr)); + } if (arg == nullptr) { return errors::Internal("Unable to procure EagerTensor from Tensor."); } @@ -84,7 +96,8 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { } PyList_SetItem(lst, i, arg); } - *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst); + *tuple = Py_BuildValue("(sON)", call->token.c_str(), + call->gpu ? Py_True : Py_False, lst); CHECK(*tuple); return Status::OK(); } @@ -150,15 +163,9 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. -Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, - Tensor* output_tensor, - TF_Status* tf_status) { - // TODO(akshayka): Lift the restriction requiring output tensors to - // lie in host memory; EagerPyFunc should be able to dispatch ops on GPU - // tensors, so we should eventually implement a GPU kernel for EagerPyFunc. - *output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory( - EagerTensor_Handle(eager_tensor), tf_status); - return StatusFromTF_Status(tf_status); +void ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + Tensor* output_tensor) { + *output_tensor = EagerTensor_Handle(eager_tensor)->t; } // Calls the registered py function through the trampoline. @@ -201,15 +208,23 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } // Process the return values and convert them to TF Tensors. - Status s; + Status s = Status::OK(); if (PyList_Check(result)) { + // `result` is a Python list; if this operation is an `EagerPyFunc`, then + // every item in the list must be an `EagerTensor`; otherwise, every element + // must be a NumPy array. call->out.clear(); for (int i = 0; i < PyList_Size(result); ++i) { Tensor t; if (call->eager) { - auto tf_status = tensorflow::make_safe(TF_NewStatus()); - s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t, - tf_status.get()); + const PyObject* item = PyList_GetItem(result, i); + if (EagerTensor_CheckExact(item)) { + ExtractTensorFromEagerTensor(item, &t); + } else { + s = errors::FailedPrecondition( + "Expected EagerTensor, found PyObject of type: ", + Py_TYPE(item)->tp_name); + } } else { s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t); } @@ -220,16 +235,15 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { call->out.push_back(t); } } else if (EagerTensor_CheckExact(result) || result == Py_None) { + // result is an `EagerTensor` or `None`. DCHECK(call->eager); Tensor t; if (result != Py_None) { - auto tf_status = tensorflow::make_safe(TF_NewStatus()); - s = ExtractTensorFromEagerTensor(result, &t, tf_status.get()); - if (s.ok()) { - call->out.push_back(t); - } + ExtractTensorFromEagerTensor(result, &t); + call->out.push_back(t); } } else if (PyArray_Check(result)) { + // `result` is a NumPy array. DCHECK(!call->eager); if (!IsSingleNone(result)) { Tensor t; @@ -239,7 +253,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } } } else { - s = errors::Internal("Unexpected pyobject is returned: ", + s = errors::Internal("Unexpected PyObject was returned: ", Py_TYPE(result)->tp_name); } Py_DECREF(result); @@ -429,12 +443,24 @@ class PyFuncOp : public OpKernel { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); eager_ = type_string() == "EagerPyFunc"; + gpu_ = ctx->device_type().type_string() == DEVICE_GPU; } void Compute(OpKernelContext* ctx) override { PyCall call; call.token = token_; + call.gpu = gpu_; call.eager = eager_; + if (call.eager) { + // Eager's C API uses `Device`, whereas `OpKernelContext` stores a + // `DeviceBase`; attempt to downcast. + call.device = dynamic_cast(ctx->device()); + if (call.device == nullptr) { + ctx->CtxFailureWithWarning( + errors::Internal("Unrecognized device class")); + } + } + for (int i = 0; i < ctx->num_inputs(); ++i) { call.ins.push_back(ctx->input(i)); } @@ -476,6 +502,9 @@ class PyFuncOp : public OpKernel { private: string token_; + // True if and only if this op has been placed on a GPU. + bool gpu_; + // True if and only if this op should execute the python function eagerly, // i.e., if and only if the eager attribute is set. bool eager_; @@ -486,5 +515,6 @@ class PyFuncOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp); +REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_GPU), PyFuncOp); } // end namespace tensorflow diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py index 4e3071d8513a28b02b70b290c4987bec92b3c32e..59f5075f177ef5335115cb4f24182d28a9b547c8 100644 --- a/tensorflow/python/lib/io/file_io.py +++ b/tensorflow/python/lib/io/file_io.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import c_api_util from tensorflow.python.framework import errors from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export class FileIO(object): @@ -235,6 +236,7 @@ class FileIO(object): self._writable_file = None +@tf_export("gfile.Exists") def file_exists(filename): """Determines whether a path exists or not. @@ -256,6 +258,7 @@ def file_exists(filename): return True +@tf_export("gfile.Remove") def delete_file(filename): """Deletes the file located at 'filename'. @@ -306,6 +309,7 @@ def write_string_to_file(filename, file_content): f.write(file_content) +@tf_export("gfile.Glob") def get_matching_files(filename): """Returns a list of files that match the given pattern(s). @@ -336,6 +340,7 @@ def get_matching_files(filename): ] +@tf_export("gfile.MkDir") def create_dir(dirname): """Creates a directory with the name 'dirname'. @@ -353,6 +358,7 @@ def create_dir(dirname): pywrap_tensorflow.CreateDir(compat.as_bytes(dirname), status) +@tf_export("gfile.MakeDirs") def recursive_create_dir(dirname): """Creates a directory and all parent/intermediate directories. @@ -368,6 +374,7 @@ def recursive_create_dir(dirname): pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(dirname), status) +@tf_export("gfile.Copy") def copy(oldpath, newpath, overwrite=False): """Copies data from oldpath to newpath. @@ -385,6 +392,7 @@ def copy(oldpath, newpath, overwrite=False): compat.as_bytes(oldpath), compat.as_bytes(newpath), overwrite, status) +@tf_export("gfile.Rename") def rename(oldname, newname, overwrite=False): """Rename or move a file / directory. @@ -426,6 +434,7 @@ def atomic_write_string_to_file(filename, contents, overwrite=True): raise +@tf_export("gfile.DeleteRecursively") def delete_recursively(dirname): """Deletes everything under dirname recursively. @@ -439,6 +448,7 @@ def delete_recursively(dirname): pywrap_tensorflow.DeleteRecursively(compat.as_bytes(dirname), status) +@tf_export("gfile.IsDirectory") def is_directory(dirname): """Returns whether the path is a directory or not. @@ -452,6 +462,7 @@ def is_directory(dirname): return pywrap_tensorflow.IsDirectory(compat.as_bytes(dirname), status) +@tf_export("gfile.ListDirectory") def list_directory(dirname): """Returns a list of entries contained within a directory. @@ -479,6 +490,7 @@ def list_directory(dirname): ] +@tf_export("gfile.Walk") def walk(top, in_order=True): """Recursive directory tree generator for directories. @@ -522,6 +534,7 @@ def walk(top, in_order=True): yield here +@tf_export("gfile.Stat") def stat(filename): """Returns file statistics for a given path. diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index df190100689bd864de78f5a2cf52b1ade081a789..48ea107a146c2714f7b59f53abbcd8b60dbf2fd4 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -22,8 +22,10 @@ from __future__ import print_function from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import errors from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export("python_io.TFRecordCompressionType") class TFRecordCompressionType(object): """The type of compression for the record.""" NONE = 0 @@ -33,6 +35,7 @@ class TFRecordCompressionType(object): # NOTE(vrv): This will eventually be converted into a proto. to match # the interface used by the C++ RecordWriter. +@tf_export("python_io.TFRecordOptions") class TFRecordOptions(object): """Options used for manipulating TFRecord files.""" compression_type_map = { @@ -51,6 +54,7 @@ class TFRecordOptions(object): return cls.compression_type_map[options.compression_type] +@tf_export("python_io.tf_record_iterator") def tf_record_iterator(path, options=None): """An iterator that read the records from a TFRecords file. @@ -81,6 +85,7 @@ def tf_record_iterator(path, options=None): reader.Close() +@tf_export("python_io.TFRecordWriter") class TFRecordWriter(object): """A class to write records to a TFRecords file. diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9541b097a94466861a83cb48ed3111563490cfba..ad409ad7e5a152bbc4312e1d16f324bb8be71c33 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# Tests for this file live in python/kernel_tests/array_ops_test.py """Support for manipulating tensors. See the @{$python/array_ops} guide. @@ -2451,8 +2452,8 @@ def _all_dimensions(x): r = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D. return constant_op.constant(np.arange(r), dtype=dtypes.int32) - # Otherwise, we rely on Range and Rank to do the right thing at run-time. - return range(0, rank(x)) + # Otherwise, we rely on `range` and `rank` to do the right thing at runtime. + return gen_math_ops._range(0, rank(x), 1) @tf_export("sequence_mask") @@ -2497,7 +2498,7 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None): maxlen = gen_math_ops._max(lengths, _all_dimensions(lengths)) else: maxlen = ops.convert_to_tensor(maxlen) - if maxlen.get_shape().ndims != 0: + if maxlen.get_shape().ndims is not None and maxlen.get_shape().ndims != 0: raise ValueError("maxlen must be scalar for sequence_mask") # The basic idea is to compare a range row vector of size maxlen: diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index 50690cd891f73df1e345817b834ce6c361bff9e8..e4ce2ab28a15f82e80194ab17ef939411982076a 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -119,7 +119,7 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, For example: ```python - tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==> + tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==> [[0 0 0 0 0] [0 0 1 0 0] [0 0 1 0 0] diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 49191c647d59691a59aa5d7dd9cc9dac285b9fea..e75eb0843f5fea5d5d512845df1677485757e32a 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -50,12 +50,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import functools import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import control_flow_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -78,6 +80,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.gen_control_flow_ops import * # pylint: enable=wildcard-import from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use @@ -261,10 +264,10 @@ def _Enter(data, data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True) if isinstance(data, ops.Tensor): if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access - result = ref_enter( + result = gen_control_flow_ops._ref_enter( data, frame_name, is_constant, parallel_iterations, name=name) else: - result = enter( + result = gen_control_flow_ops._enter( data, frame_name, is_constant, parallel_iterations, name=name) if use_input_shape: result.set_shape(data.get_shape()) @@ -279,7 +282,7 @@ def _Enter(data, parallel_iterations=parallel_iterations, use_input_shape=use_input_shape, name=name) - indices = enter( + indices = gen_control_flow_ops._enter( data.indices, frame_name, is_constant, @@ -290,7 +293,7 @@ def _Enter(data, if isinstance(data, ops.IndexedSlices): dense_shape = data.dense_shape if dense_shape is not None: - dense_shape = enter( + dense_shape = gen_control_flow_ops._enter( dense_shape, frame_name, is_constant, @@ -300,7 +303,7 @@ def _Enter(data, dense_shape.set_shape(data.dense_shape.get_shape()) return ops.IndexedSlices(values, indices, dense_shape) else: - dense_shape = enter( + dense_shape = gen_control_flow_ops._enter( data.dense_shape, frame_name, is_constant, @@ -1498,7 +1501,10 @@ class ControlFlowContext(object): """ def __init__(self, values_def=None, import_scope=None): + self._nested_contexts = [] self._outer_context = ops.get_default_graph()._get_control_flow_context() + if self._outer_context: + self._outer_context._nested_contexts.append(self) # pylint: disable=protected-access self._context_stack = [] if values_def: self._init_values_from_proto(values_def, import_scope=import_scope) @@ -1551,7 +1557,17 @@ class ControlFlowContext(object): def back_prop(self): raise NotImplementedError("Abstract method") - def _to_proto(self, export_scope=None): + @abc.abstractmethod + def to_control_flow_context_def(self, context_def, export_scope=None): + """Serializes this into `context_def`. + + Args: + context_def: a `ControlFlowContextDef` protocol buffer. + export_scope: Optional `string`. Name scope to remove. + """ + raise NotImplementedError("Abstract method") + + def _to_values_def(self, export_scope=None): """Converts the values to a `ValuesDef` protocol buffer. Args: @@ -1568,11 +1584,6 @@ class ControlFlowContext(object): values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope) return values_def - @staticmethod - def _from_proto(values_def, import_scope=None): - """Returns a `ControlFlowContext` created from `values_def`.""" - return ControlFlowContext(values_def=values_def, import_scope=import_scope) - def AddName(self, name): self._values.add(name) @@ -1751,8 +1762,15 @@ class CondContext(ControlFlowContext): context_def.pivot_name = ops.strip_name_scope(self._pivot.name, export_scope) context_def.branch = self._branch - context_def.values_def.MergeFrom( - super(CondContext, self)._to_proto(export_scope)) + context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def( + export_scope)) + # TODO(b/72868227): enable this once the corresponding control_flow.proto + # changes have been checked in (they aren't checked in and this is + # disabled for now to ensure forwards compatibility). + if False: # pylint: disable=using-constant-test + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -1761,7 +1779,21 @@ class CondContext(ControlFlowContext): @staticmethod def from_proto(context_def, import_scope=None): """Returns a `CondContext` object created from `context_def`.""" - return CondContext(context_def=context_def, import_scope=import_scope) + ret = CondContext(context_def=context_def, + import_scope=import_scope) + + # TODO(b/72868227): remove "if hasattr(...)" once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is here for now to ensure forwards compatibility). + if hasattr(context_def, "nested_contexts"): + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def) + ret.Exit() + return ret + + def to_control_flow_context_def(self, context_def, export_scope=None): + context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) def AddValue(self, val): """Add `val` to the current context and its outer context recursively.""" @@ -2067,9 +2099,15 @@ def cond(pred, merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)] merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges) - # Add to collections - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) + # Only add non-nested conds to the collection. Any nested control flow will + # be encapsulated in the root context. + assert context_t.outer_context == context_f.outer_context + # TODO(b/72868227): remove "if True..." once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if True or context_t.outer_context is None: + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges) @@ -2206,6 +2244,17 @@ class WhileContext(ControlFlowContext): super(WhileContext, self).__init__( values_def=context_def.values_def, import_scope=import_scope) + # import_scope causes self.name to be different from the original serialized + # context's name. Rewrite "frame_name" attrs with the new name. + if import_scope: + for tensor_name in self._values: + op = g.as_graph_element(tensor_name).op + if util.IsLoopEnter(op): + # pylint: disable=protected-access + op._set_attr("frame_name", + attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) + # pylint: enable=protected-access + @property def maximum_iterations(self): """The maximum number of iterations that will be executed.""" @@ -2277,12 +2326,23 @@ class WhileContext(ControlFlowContext): ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters ]) context_def.values_def.MergeFrom( - super(WhileContext, self)._to_proto(export_scope=export_scope)) + super(WhileContext, self)._to_values_def( + export_scope=export_scope)) + # TODO(b/72868227): remove "if True..." once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if False: # pylint: disable=using-constant-test + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: return None + def to_control_flow_context_def(self, context_def, export_scope=None): + context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) + @staticmethod def from_proto(context_def, import_scope=None): """Returns a `WhileContext` object created from `context_def`. @@ -2294,7 +2354,17 @@ class WhileContext(ControlFlowContext): Returns: A `WhileContext` Python object. """ - return WhileContext(context_def=context_def, import_scope=import_scope) + ret = WhileContext(context_def=context_def, + import_scope=import_scope) + # TODO(b/72868227): remove "if hasattr(...)" once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if hasattr(context_def, "nested_contexts"): + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def, import_scope=import_scope) + ret.Exit() + return ret def GetWhileContext(self): return self @@ -3092,7 +3162,13 @@ def while_loop(cond, parallel_iterations=parallel_iterations, back_prop=back_prop, swap_memory=swap_memory) - ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) + # Only add non-nested loops to the collection. Any nested control flow will + # be encapsulated in the root context. + # TODO(b/72868227): enable condition once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if True or loop_context.outer_context is None: + ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) if maximum_iterations is not None: return result[1] @@ -3540,6 +3616,26 @@ class XLAControlFlowContext(ControlFlowContext): return x +def from_control_flow_context_def(context_def, import_scope=None): + """Deserializes `context_def` into the appropriate ControlFlowContext. + + Args: + context_def: ControlFlowContextDef proto + import_scope: Optional `string`. Name scope to add. + + Returns: + A ControlFlowContext subclass + """ + if context_def.HasField("cond_ctxt"): + return CondContext.from_proto(context_def.cond_ctxt, + import_scope=import_scope) + if context_def.HasField("while_ctxt"): + return WhileContext.from_proto(context_def.while_ctxt, + import_scope=import_scope) + raise NotImplementedError("Unknown ControlFlowContextDef field: %s" + % context_def.WhichOneof("ctxt")) + + ops.register_proto_function( ops.GraphKeys.COND_CONTEXT, proto_type=control_flow_pb2.CondContextDef, diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index cc5a42bf3ddd4b37d037f8d28a2fe6af79f79ba1..f942f478f25929699766b7ecbfb46a354ccc8fc5 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -483,8 +483,8 @@ class ContextTest(test_util.TensorFlowTestCase): c._values = ["a", "b"] c._external_values = {"a": b1} - c_with_scope = control_flow_ops.ControlFlowContext._from_proto( - c._to_proto(), import_scope="test_scope") + c_with_scope = control_flow_ops.ControlFlowContext( + values_def=c._to_values_def(), import_scope="test_scope") # _values and _external_values should be have scope prepended. self.assertEquals( @@ -494,8 +494,8 @@ class ContextTest(test_util.TensorFlowTestCase): # Calling _to_proto() with export_scope should remove "test_scope". self.assertProtoEquals( - c._to_proto(), - c_with_scope._to_proto(export_scope="test_scope")) + c._to_values_def(), + c_with_scope._to_values_def(export_scope="test_scope")) def _GetNestedShape(nested): diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py index 44d64070ce48c0c115ea7edb1237124bc6698e90..ed435557fde7a2e8a0a4f7eef4e240daef0565e7 100644 --- a/tensorflow/python/ops/distributions/bijector_impl.py +++ b/tensorflow/python/ops/distributions/bijector_impl.py @@ -114,7 +114,7 @@ class _Mapping(collections.namedtuple( @six.add_metaclass(abc.ABCMeta) @tf_export("distributions.bijectors.Bijector") class Bijector(object): - """Interface for transformations of a `Distribution` sample. + r"""Interface for transformations of a `Distribution` sample. Bijectors can be used to represent any differentiable and injective (one to one) function defined on an open subset of `R^n`. Some non-injective @@ -122,27 +122,24 @@ class Bijector(object): #### Mathematical Details - A `Bijector` implements a - [diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism), i.e., a - bijective, differentiable function. A `Bijector` is used by - `TransformedDistribution` but can be generally used for transforming a - `Distribution` generated `Tensor`. A `Bijector` is characterized by three - operations: - - 1. Forward Evaluation + A `Bijector` implements a [smooth covering map]( + https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local + diffeomorphism such that every point in the target has a neighborhood evenly + covered by a map ([see also]( + https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)). + A `Bijector` is used by `TransformedDistribution` but can be generally used + for transforming a `Distribution` generated `Tensor`. A `Bijector` is + characterized by three operations: + 1. Forward\ Useful for turning one random outcome into another random outcome from a different distribution. - - 2. Inverse Evaluation - + 2. Inverse\ Useful for "reversing" a transformation to compute one probability in terms of another. - - 3. (log o det o Jacobian o inverse)(x) - + 3. `(log o det o Jacobian o inverse)(x)`\ "The log of the determinant of the matrix of all first-order partial - derivatives of the inverse function." + derivatives of the inverse function."\ Useful for inverting a transformation to compute one probability in terms of another. Geometrically, the det(Jacobian) is the volume of the transformation and is used to scale the probability. diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index cab1025df11d26064c3d2939598f1c58ab104736..22636fdbb3f5c512f25ad0c1d7eb4e18056da211 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -1691,7 +1691,8 @@ def rgb_to_yiq(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) @@ -1717,7 +1718,8 @@ def yiq_to_rgb(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_yiq_to_rgb_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _yiq_to_rgb_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) @@ -1742,7 +1744,8 @@ def rgb_to_yuv(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) @@ -1768,7 +1771,8 @@ def yuv_to_rgb(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 6a516a99110d2578c905b71020d9242e5636a72b..82b77ee8e3792596ec4c50ac24da1a1c38cc634b 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -3169,6 +3169,46 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): boxes, scores, max_output_size, iou_threshold).eval() self.assertAllClose(selected_indices, [3, 0, 5]) + def testInvalidShape(self): + # The boxes should be 2D of shape [num_boxes, 4]. + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 2 but is rank 1'): + boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0]) + scores = constant_op.constant([0.9]) + selected_indices = image_ops.non_max_suppression( + boxes, scores, 3, 0.5) + + with self.assertRaisesRegexp( + ValueError, 'Dimension must be 4 but is 3'): + boxes = constant_op.constant([[0.0, 0.0, 1.0]]) + scores = constant_op.constant([0.9]) + selected_indices = image_ops.non_max_suppression( + boxes, scores, 3, 0.5) + + # The scores should be 1D of shape [num_boxes]. + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 1 but is rank 2'): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([[0.9]]) + selected_indices = image_ops.non_max_suppression( + boxes, scores, 3, 0.5) + + # The max_output_size should be a scaler (0-D). + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([0.9]) + selected_indices = image_ops.non_max_suppression( + boxes, scores, [3], 0.5) + + # The iou_threshold should be a scaler (0-D). + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 2'): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([0.9]) + selected_indices = image_ops.non_max_suppression( + boxes, scores, 3, [[0.5]]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index db33a08137e1d2508314c2d28bdbbb001198e6c1..a5096ffdd9ca1a50c26857ee40f02918fa90eda8 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -65,8 +65,8 @@ def logdet(matrix, name=None): ``` Args: - matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or - `complex128` with shape `[..., M, M]`. + matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, + or `complex128` with shape `[..., M, M]`. name: A name to give this `Op`. Defaults to `logdet`. Returns: @@ -99,8 +99,8 @@ def adjoint(matrix, name=None): # [3 - 3j, 6 - 6j]] Args: - matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or - `complex128` with shape `[..., M, M]`. + matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, + or `complex128` with shape `[..., M, M]`. name: A name to give this `Op` (optional). Returns: diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 27e0f17020afa0fd44ec11c49b7a77d4426933dd..8339c940af8ae46ae3cf6ff59931c013863e66be 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -478,7 +478,6 @@ class LinearOperator(object): cond, self._max_condition_number_to_be_non_singular(), message="Singular matrix up to precision epsilon.") - raise NotImplementedError("assert_non_singular is not implemented.") def _max_condition_number_to_be_non_singular(self): """Return the maximum condition number that we consider nonsingular.""" diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index a4724d030f388230cf85cc68bf60b6553b409c17..2217bfd54593129712bc1e5a60bf2e1d5f88939b 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -121,8 +121,8 @@ class LinearOperatorDiag(linear_operator.LinearOperator): Args: diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. - The diagonal of the operator. Allowed dtypes: `float32`, `float64`, - `complex64`, `complex128`. + The diagonal of the operator. Allowed dtypes: `float16`, `float32`, + `float64`, `complex64`, `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. If `diag.dtype` is real, this is auto-set to `True`. @@ -167,7 +167,12 @@ class LinearOperatorDiag(linear_operator.LinearOperator): def _check_diag(self, diag): """Static check of diag.""" allowed_dtypes = [ - dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.complex64, + dtypes.complex128, + ] dtype = diag.dtype if dtype not in allowed_dtypes: diff --git a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py index dd4c7cb0413013f3f54f6085a7adcb523755a603..8fb59ca1a7e68450b71728c87e8a3c1f314e2ec4 100644 --- a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py +++ b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py @@ -114,7 +114,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): Args: matrix: Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`. - Allowed dtypes: `float32`, `float64`, `complex64`, `complex128`. + Allowed dtypes: `float16`, `float32`, `float64`, `complex64`, + `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. @@ -147,7 +148,12 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): def _check_matrix(self, matrix): """Static check of the `matrix` argument.""" allowed_dtypes = [ - dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.complex64, + dtypes.complex128, + ] matrix = ops.convert_to_tensor(matrix, name="matrix") diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index ad3bb2efa94bfa9751c31ff0c704aad8faa58ba7..36eed89db60a16adaf045e9fa332c38844c964d8 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -150,8 +150,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): `is_X` matrix property hints, which will trigger the appropriate code path. Args: - base_operator: Shape `[B1,...,Bb, M, N]` real `float32` or `float64` - `LinearOperator`. This is `L` above. + base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or + `float64` `LinearOperator`. This is `L` above. u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`. This is `U` above. diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype` @@ -188,7 +188,11 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): # because if diag has non-zero imaginary part, it will not be # self-adjoint positive definite. dtype = base_operator.dtype - allowed_dtypes = [dtypes.float32, dtypes.float64] + allowed_dtypes = [ + dtypes.float16, + dtypes.float32, + dtypes.float64, + ] if dtype not in allowed_dtypes: raise TypeError( "Argument matrix must have dtype in %s. Found: %s" diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py index 6ea55f0367bd55379b280f81f22df2c3a0dcfb1e..6419030755f59438081c15b8f1e04e3af59061bb 100644 --- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py +++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py @@ -118,7 +118,8 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): Args: tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`. The lower triangular part of `tril` defines this operator. The strictly - upper triangle is ignored. Allowed dtypes: `float32`, `float64`. + upper triangle is ignored. Allowed dtypes: `float16`, `float32`, + `float64`. is_non_singular: Expect that this operator is non-singular. This operator is non-singular if and only if its diagonal elements are all non-zero. @@ -164,7 +165,11 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): """Static check of the `tril` argument.""" # TODO(langmore) Add complex types once matrix_triangular_solve works for # them. - allowed_dtypes = [dtypes.float32, dtypes.float64] + allowed_dtypes = [ + dtypes.float16, + dtypes.float32, + dtypes.float64, + ] dtype = tril.dtype if dtype not in allowed_dtypes: raise TypeError( diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 72508eb4350f57bb06b3829890f92554677c98d5..8b3c61b9339734d6a596d92e93f7a69d32dddd12 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -28,8 +28,10 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import util from tensorflow.python.util.deprecation import deprecated_args +from tensorflow.python.util.tf_export import tf_export +@tf_export("losses.Reduction") class Reduction(object): """Types of loss reduction. @@ -149,9 +151,10 @@ def _num_present(losses, weights, per_batch=False): def _num_elements(losses): """Computes the number of elements in `losses` tensor.""" with ops.name_scope(None, "num_elements", values=[losses]) as scope: - return array_ops.size(losses, name=scope, out_type=losses.dtype) + return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype) +@tf_export("losses.compute_weighted_loss") def compute_weighted_loss( losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -211,6 +214,7 @@ def compute_weighted_loss( return loss +@tf_export("losses.absolute_difference") def absolute_difference( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -258,6 +262,7 @@ def absolute_difference( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.cosine_distance") @deprecated_args(None, "dim is deprecated, use axis instead", "dim") def cosine_distance( labels, predictions, axis=None, weights=1.0, scope=None, @@ -311,6 +316,7 @@ def cosine_distance( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.hinge_loss") def hinge_loss(labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -352,6 +358,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.huber_loss") def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -420,6 +427,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.log_loss") def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -471,6 +479,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, # TODO(b/37208492): Add reduction arg. +@tf_export("losses.mean_pairwise_squared_error") def mean_pairwise_squared_error( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES): @@ -538,12 +547,13 @@ def mean_pairwise_squared_error( num_present_per_batch = _num_present(diffs, weights, per_batch=True) term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch) + num_present_per_batch-1) sum_diff = math_ops.reduce_sum( diffs, reduction_indices=reduction_indices, keep_dims=True) - term2 = 2.0 * _safe_div(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch)) + term2 = 2.0 * _safe_div( + math_ops.square(sum_diff), + math_ops.multiply(num_present_per_batch, num_present_per_batch-1)) weighted_losses = math_ops.multiply(term1 - term2, weights) loss = math_ops.reduce_sum(weighted_losses) @@ -557,6 +567,7 @@ def mean_pairwise_squared_error( return mean_loss +@tf_export("losses.mean_squared_error") def mean_squared_error( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -604,6 +615,7 @@ def mean_squared_error( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.sigmoid_cross_entropy") def sigmoid_cross_entropy( multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -662,6 +674,7 @@ def sigmoid_cross_entropy( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.softmax_cross_entropy") def softmax_cross_entropy( onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -771,6 +784,7 @@ def _remove_squeezable_dimensions( return labels, predictions, weights +@tf_export("losses.sparse_softmax_cross_entropy") def sparse_softmax_cross_entropy( labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py index 3718c481c26afdd9f007ffc22a9e6ec44a1eb10e..b835d963869704f053de6c2f8a75ae1fa72e6a5d 100644 --- a/tensorflow/python/ops/losses/util.py +++ b/tensorflow/python/ops/losses/util.py @@ -30,8 +30,10 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("losses.add_loss") def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES): """Adds a externally defined loss to the collection of losses. @@ -43,6 +45,7 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES): ops.add_to_collection(loss_collection, loss) +@tf_export("losses.get_losses") def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES): """Gets the list of losses from the loss_collection. @@ -56,6 +59,7 @@ def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES): return ops.get_collection(loss_collection, scope) +@tf_export("losses.get_regularization_losses") def get_regularization_losses(scope=None): """Gets the list of regularization losses. @@ -68,6 +72,7 @@ def get_regularization_losses(scope=None): return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope) +@tf_export("losses.get_regularization_loss") def get_regularization_loss(scope=None, name="total_regularization_loss"): """Gets the total regularization loss. @@ -85,6 +90,7 @@ def get_regularization_loss(scope=None, name="total_regularization_loss"): return constant_op.constant(0.0) +@tf_export("losses.get_total_loss") def get_total_loss(add_regularization_losses=True, name="total_loss"): """Returns a tensor whose value represents the total loss. diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 53308484c427e715f649a09f0dbe3f1448f18f5b..c6cc4e186074e71b8742e4aa5b69a699f77f250e 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -228,56 +228,142 @@ def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad): dim0), None, None, None) -def _SegmentMinOrMaxGrad(op, grad, is_sorted): - """Gradient for SegmentMin and (unsorted) SegmentMax. - - They share similar code. - """ - zeros = array_ops.zeros( - array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype) - +def _SegmentMinOrMaxGrad(op, grad): + """ Gradient for SegmentMin and SegmentMax. """ + zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype) # Get the number of selected (minimum or maximum) elements in each segment. gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1]) is_selected = math_ops.equal(op.inputs[0], gathered_outputs) - if is_sorted: - num_selected = math_ops.segment_sum( - math_ops.cast(is_selected, grad.dtype), op.inputs[1]) - else: - num_selected = math_ops.unsorted_segment_sum( - math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) - + num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype), + op.inputs[1]) # Compute the gradient for each segment. The gradient for the ith segment is # divided evenly among the selected elements in that segment. weighted_grads = math_ops.div(grad, num_selected) gathered_grads = array_ops.gather(weighted_grads, op.inputs[1]) - - if is_sorted: - return array_ops.where(is_selected, gathered_grads, zeros), None - else: - return array_ops.where(is_selected, gathered_grads, zeros), None, None + return array_ops.where(is_selected, gathered_grads, zeros), None @ops.RegisterGradient("SegmentMin") def _SegmentMinGrad(op, grad): """Gradient for SegmentMin.""" - return _SegmentMinOrMaxGrad(op, grad, True) + return _SegmentMinOrMaxGrad(op, grad) @ops.RegisterGradient("SegmentMax") def _SegmentMaxGrad(op, grad): """Gradient for SegmentMax.""" - return _SegmentMinOrMaxGrad(op, grad, True) + return _SegmentMinOrMaxGrad(op, grad) + + +def _GatherDropNegatives(params, ids, zero_clipped_indices=None, + is_positive=None): + """ Helper function for unsorted segment ops. Gathers params for + positive segment ids and gathers 0 for inputs with negative segment id. + Also returns the clipped indices and a boolean mask with the same shape + as ids where a positive id is masked as true. With this, the latter two + can be passed as arguments to this function to reuse them. + """ + if zero_clipped_indices is None: + zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids)) + gathered = array_ops.gather(params, zero_clipped_indices) + if is_positive is None: + is_positive = math_ops.greater_equal(ids, 0) + # tf.where(condition, x, y) requires condition to have the same shape as x + # and y. + # todo(philjd): remove this if tf.where supports broadcasting (#9284) + for _ in range(gathered.shape.ndims - is_positive.shape.ndims): + is_positive = array_ops.expand_dims(is_positive, -1) + is_positive = (is_positive & + array_ops.ones_like(gathered, dtype=dtypes.bool)) + # replace gathered params of negative indices with 0 + zero_slice = array_ops.zeros_like(gathered) + return (array_ops.where(is_positive, gathered, zero_slice), + zero_clipped_indices, is_positive) + + +def _UnsortedSegmentMinOrMaxGrad(op, grad): + """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """ + # Get the number of selected (minimum or maximum) elements in each segment. + gathered_outputs, zero_clipped_indices, is_positive = \ + _GatherDropNegatives(op.outputs[0], op.inputs[1]) + is_selected = math_ops.equal(op.inputs[0], gathered_outputs) + is_selected = math_ops.logical_and(is_selected, is_positive) + num_selected = math_ops.unsorted_segment_sum( + math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) + # Compute the gradient for each segment. The gradient for the ith segment is + # divided evenly among the selected elements in that segment. + weighted_grads = math_ops.div(grad, num_selected) + gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, + zero_clipped_indices, + is_positive) + zeros = array_ops.zeros_like(gathered_grads) + return array_ops.where(is_selected, gathered_grads, zeros), None, None @ops.RegisterGradient("UnsortedSegmentSum") def _UnsortedSegmentSumGrad(op, grad): - """Gradient for SegmentSum.""" - return array_ops.gather(grad, op.inputs[1]), None, None + """Gradient for UnsortedSegmentSum.""" + return _GatherDropNegatives(grad, op.inputs[1])[0], None, None @ops.RegisterGradient("UnsortedSegmentMax") def _UnsortedSegmentMaxGrad(op, grad): - return _SegmentMinOrMaxGrad(op, grad, False) + """ Gradient for UnsortedSegmentMax. """ + return _UnsortedSegmentMinOrMaxGrad(op, grad) + + +@ops.RegisterGradient("UnsortedSegmentMin") +def _UnsortedSegmentMinGrad(op, grad): + """ Gradient for UnsortedSegmentMin. """ + return _UnsortedSegmentMinOrMaxGrad(op, grad) + + +@ops.RegisterGradient("UnsortedSegmentProd") +def _UnsortedSegmentProdGrad(op, grad): + """ Gradient for UnsortedSegmentProd. + The gradient can be expressed for each segment by dividing the segment's + product by each element of the segment input tensor, but this approach can't + deal with zeros in the input. + Unlike reduce_prod we can't use cumsum here as individual segments may have + a different number of elements. Therefore we consider three cases: + 1) A segment input contains no zeros and we can safely divide by the input + tensor. + 2) A segment contains exactly one zero. Then the gradient of each input of + the segment is zero except for the 0-input, there the gradient is + the product of the remaining segment entries. + 3) A segment contains at least two zeros. The gradient is zero for all + segment inputs. + """ + # Note that unsorted_segment_sum will filter out the negative indices, + # so we don't need to do a logical_and with is_positive here + is_zero = math_ops.equal(op.inputs[0], 0) + num_zeros = gen_math_ops.unsorted_segment_sum( + math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2]) + # handle case 3 and set the gradient to 0 for segments with more than one + # 0 as input + grad = array_ops.where(math_ops.greater(num_zeros, 1), + array_ops.zeros_like(grad), grad) + # replace all zeros with ones and compute the unsorted_segment_prod + non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]), + op.inputs[0]) + non_zero_prod = gen_math_ops.unsorted_segment_prod( + non_zero_data, op.inputs[1], op.inputs[2]) + # clip the indices for gather to be positive + zero_clipped_indices = math_ops.maximum(op.inputs[1], + array_ops.zeros_like(op.inputs[1])) + gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices) + gathered_non_zero_prod = array_ops.gather(non_zero_prod, + zero_clipped_indices) + prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf. + # Now fetch the individual results for segments containing 0 and those that + # don't. is_zero will also fetch results for entries with negative index + # but the following gather_drop_negatives sets the corresponding entry in + # grad to 0 for these + partial_derivative = array_ops.where(is_zero, gathered_non_zero_prod, + prod_divided_by_el) + gathered_grad = _GatherDropNegatives(grad, op.inputs[1], + zero_clipped_indices)[0] + return gathered_grad * partial_derivative, None, None @ops.RegisterGradient("Abs") diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 9a8ac93de9dcc12c513b5ddd07cca9d863d19b8a..aac72b331eaa00f65df1934fd1025c10a928ced8 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -131,6 +131,9 @@ See the @{$python/math_ops} guide. @@segment_mean @@unsorted_segment_sum @@unsorted_segment_max +@@unsorted_segment_min +@@unsorted_segment_prod +@@unsorted_segment_sqrt_n @@sparse_segment_sum @@sparse_segment_mean @@sparse_segment_sqrt_n @@ -2552,6 +2555,87 @@ def reduced_shape(input_shape, axes): ]) # [1, 1] +def _unsorted_segment_N(data, segment_ids, num_segments): + """ Helper function for unsorted_segment_mean/_sqrtN. Computes the number + of segment entries with 0-entries set to 1 to allow division by N. + """ + # bincount doesn't support negative indices so we use unsorted_segment_sum + ones_tensor = array_ops.ones(segment_ids.shape, dtype=data.dtype) + N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments) + # add dimensions for all non-reduced axes + ndims_output = data.shape.ndims - segment_ids.shape.ndims + broadcast_shape = [num_segments] + [1] * ndims_output + N = array_ops.reshape(N, broadcast_shape) + return gen_math_ops.maximum(N, 1) + + +@tf_export("unsorted_segment_mean") +def unsorted_segment_mean(data, segment_ids, num_segments, name=None): + r""" Computes the mean along segments of a tensor. + + Read @{$math_ops#segmentation$the section on segmentation} for an explanation + of segments. + + This operator is similar to the unsorted segment sum operator found + [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). + Instead of computing the sum over segments, it computes the mean of all + entries belonging to a segment such that: + + \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such + that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences + of id \\i\\. + + If there is no entry for a given segment ID `i`, it outputs 0. + + segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s + first dimension. + + output: Has same shape as data, except for dimension 0 which + has size `num_segments`. + """ + with ops.name_scope(name, "UnsortedSegmentMean"): + data = ops.convert_to_tensor(data) + segment_ids = ops.convert_to_tensor(segment_ids) + N = _unsorted_segment_N(data, segment_ids, num_segments) + summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments) + return summed / N + + +@tf_export("unsorted_segment_sqrt_n") +def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None): + r"""Computes the sum along segments of a tensor divided by the sqrt(N). + + Read @{$math_ops#segmentation$the section on segmentation} for an explanation + of segments. + + This operator is similar to the unsorted segment sum operator found + [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). + Additionally to computing the sum over segments, it divides the results by + sqrt(N). + + \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such + that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences + of id \\i\\. + + If there is no entry for a given segment ID `i`, it outputs 0. + + Note that this op only supports floating point and complex dtypes, + due to tf.sqrt only supporting these types. + + segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s + first dimension. + + output: Has same shape as data, except for dimension 0 which + has size `num_segments`. + """ + with ops.name_scope(name, "UnsortedSegmentSqrtN"): + data = ops.convert_to_tensor(data) + segment_ids = ops.convert_to_tensor(segment_ids) + N = _unsorted_segment_N(data, segment_ids, num_segments) + summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments) + return summed / gen_math_ops.sqrt(N) + + @tf_export("sparse_segment_sum") def sparse_segment_sum(data, indices, segment_ids, name=None, num_segments=None): diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 9c875b4bcb11f0c1c0ccd724f118ec2d864ef7ed..a691e281ee7f0ce4ee3253069b095a88df5723c5 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -2116,13 +2116,12 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """ with ops.name_scope(name, "MaxPool", [value]) as name: value = ops.convert_to_tensor(value, name="input") - return gen_nn_ops._max_pool( - value, - ksize=ksize, - strides=strides, - padding=padding, - data_format=data_format, - name=name) + return gen_nn_ops._max_pool(value, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format, + name=name) @ops.RegisterStatistics("Conv2D", "flops") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index bdf41cd75d6432750b7b23391c28892e2d6b9ffc..cc9f7981e4148cb117693fb3f83153dfba9c5895 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -348,9 +348,9 @@ class ResourceVariable(variables.Variable): if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] self._save_slice_info = None - # Save the graph's container prefix for error checking. Reading the value of - # the ResourceVariable from another Graph in Eager mode is an error. - self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access + # Store the graph key so optimizers know how to only retrieve variables from + # this graph. + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with ops.init_scope(): self._in_graph_mode = context.in_graph_mode() with ops.name_scope(name, "Variable", [] @@ -662,15 +662,7 @@ class ResourceVariable(variables.Variable): Returns: the read operation. - Raises: - ValueError: if the ResourceVariable was created in another isolation - environment or graph. """ - if (not self._in_graph_mode and - self._container_prefix != ops.get_default_graph()._container_prefix): # pylint: disable=protected-access - raise ValueError( - "Attempted to read a variable from another isolation environment" - " or Graph") with ops.name_scope("Read"): # Ensure we read the variable in the same device as the handle. with ops.device(self._handle_device): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index e0052b8869dd2cf331c14e2355d4b40dd217c561..da80e72071c095207dcfaf18681ecf5e5998e0d6 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -171,11 +171,11 @@ def _rnn_step( return (final_output, final_state) Args: - time: Python int, the current time step - sequence_length: int32 `Tensor` vector of size [batch_size] - min_sequence_length: int32 `Tensor` scalar, min of sequence_length - max_sequence_length: int32 `Tensor` scalar, max of sequence_length - zero_output: `Tensor` vector of shape [output_size] + time: int32 `Tensor` scalar. + sequence_length: int32 `Tensor` vector of size [batch_size]. + min_sequence_length: int32 `Tensor` scalar, min of sequence_length. + max_sequence_length: int32 `Tensor` scalar, max of sequence_length. + zero_output: `Tensor` vector of shape [output_size]. state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, or a list/tuple of such tensors. call_cell: lambda returning tuple of (new_output, new_state) where @@ -202,6 +202,9 @@ def _rnn_step( flat_state = nest.flatten(state) flat_zero_output = nest.flatten(zero_output) + # Vector describing which batch entries are finished. + copy_cond = time >= sequence_length + def _copy_one_through(output, new_output): # TensorArray and scalar get passed through. if isinstance(output, tensor_array_ops.TensorArray): @@ -209,7 +212,6 @@ def _rnn_step( if output.shape.ndims == 0: return new_output # Otherwise propagate the old or the new value. - copy_cond = (time >= sequence_length) with ops.colocate_with(new_output): return array_ops.where(copy_cond, output, new_output) diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 4b5072fd6799ae289d3c1a1b2a40878e36604bf4..1b9071ee93c21f8d6bdc9ace11dbf57f3eb3e218 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -50,19 +50,21 @@ class EagerFunc(object): self._func = func self._out_dtypes = Tout - def __call__(self, *args, **kwargs): - """Passes args, kwargs to `self._func`, which is executed eagerly.""" + def __call__(self, on_gpu, args): + """Passes `args` to `self._func`, which is executed eagerly.""" with context.eager_mode(): - ret = self._func(*args, **kwargs) + ret = self._func(*args) + maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu() if isinstance(ret, (tuple, list)): return [ - ops.convert_to_tensor(x, dtype=dtype) + maybe_copy_to_gpu(ops.convert_to_tensor(x, dtype=dtype)) for (x, dtype) in zip(ret, self._out_dtypes) ] elif ret is None: return ret else: - return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0]) + return maybe_copy_to_gpu( + ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])) class FuncRegistry(object): @@ -116,16 +118,29 @@ class FuncRegistry(object): else: return result - def __call__(self, token, args): - """Calls the registered function for `token` with args.""" + def __call__(self, token, on_gpu, args): + """Calls the registered function for `token` with args. + + Args: + token: A key into this `FuncRegistry` identifying which function to call. + on_gpu: A boolean indicating whether or not `token`'s corresponding + operation was placed on GPU; only used if the function registered for + `token` is an `EagerPyFunc`. + args: The arguments to pass to the function registered for `token`. + + Returns: + The output of the function registered for `token`. + + Raises: + ValueError: if no function is registered for `token`. + """ func = self._funcs[token] if func is None: raise ValueError("callback %s is not found" % token) - ret = func(*args) - if isinstance(func, EagerFunc): - return ret + return func(on_gpu, args) else: + ret = func(*args) # Strings seem to lead to a memory leak here if they're not wrapped in a # list. if isinstance(ret, six.binary_type): @@ -302,8 +317,5 @@ def py_func(func, inp, Tout, stateful=True, name=None): func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name) -# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be -# differentiable, i.e., the gradient of PyFunc should propagate Nones if the -# eager attribute is not set, and otherwise, it should return the gradient. ops.NotDifferentiable("PyFunc") ops.NotDifferentiable("PyFuncStateless") diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 3cc76fdbf34ff6de47d98400cd826d671c9178eb..f00213eb88dce8e7bf73264a54780a704b4c4b18 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -278,7 +278,7 @@ def assign(ref, value, validate_shape=None, use_locking=None, name=None): return gen_state_ops.assign( ref, value, use_locking=use_locking, name=name, validate_shape=validate_shape) - return ref.assign(value) + return ref.assign(value, name=name) @tf_export("count_up_to") diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py index 9b92d9a18005ca5e6be3820427e3a3ba60a8ec2d..cce64c0ccafc29a9d0d0b51b4c97c5673264657b 100644 --- a/tensorflow/python/platform/app.py +++ b/tensorflow/python/platform/app.py @@ -23,6 +23,7 @@ import sys as _sys from tensorflow.python.platform import flags from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export def _usage(shorthelp): @@ -108,6 +109,7 @@ def _define_help_flags(): _define_help_flags_called = True +@tf_export('app.run') def run(main=None, argv=None): """Runs the program with an optional 'main' function and 'argv' list.""" diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py index 2455acb4c0c469acbb928c4ec44571e50e06de1f..8f7b12e2b2b92d9b2bfe397d0e7cba59e11bc1f6 100644 --- a/tensorflow/python/platform/resource_loader.py +++ b/tensorflow/python/platform/resource_loader.py @@ -29,8 +29,10 @@ import sys as _sys from tensorflow.python.util import tf_inspect as _inspect from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export +@tf_export('resource_loader.load_resource') def load_resource(path): """Load the resource at given path, where path is relative to tensorflow/. @@ -52,6 +54,7 @@ def load_resource(path): # pylint: disable=protected-access +@tf_export('resource_loader.get_data_files_path') def get_data_files_path(): """Get a direct path to the data files colocated with the script. @@ -62,6 +65,7 @@ def get_data_files_path(): return _os.path.dirname(_inspect.getfile(_sys._getframe(1))) +@tf_export('resource_loader.get_root_dir_with_all_resources') def get_root_dir_with_all_resources(): """Get a root directory containing all the data attributes in the build rule. @@ -101,6 +105,7 @@ def get_root_dir_with_all_resources(): return data_files_dir or script_dir +@tf_export('resource_loader.get_path_to_datafile') def get_path_to_datafile(path): """Get the path to the specified file in the data dependencies. @@ -120,6 +125,7 @@ def get_path_to_datafile(path): return _os.path.join(data_files_path, path) +@tf_export('resource_loader.readahead_file_path') def readahead_file_path(path, readahead='128M'): # pylint: disable=unused-argument """Readahead files not implemented; simply returns given path.""" return path diff --git a/tensorflow/python/platform/stacktrace_handler_test.py b/tensorflow/python/platform/stacktrace_handler_test.py index 3f0e534f4cbd97ecbd7db1fae3b48af72310c24f..f2071f9d54ceb99831999ec08ab71d63862f1c36 100644 --- a/tensorflow/python/platform/stacktrace_handler_test.py +++ b/tensorflow/python/platform/stacktrace_handler_test.py @@ -57,7 +57,8 @@ class StacktraceHandlerTest(test.TestCase): # Capture its output. capture both stdout and stderr and append them. # We are not worried about timing or order of messages in this test. - child_output = child_process.stdout.read() + child_process.stderr.read() + child_stdout, child_stderr = child_process.communicate() + child_output = child_stdout + child_stderr # Make sure the child process is dead before we proceed. child_process.wait() diff --git a/tensorflow/python/platform/tf_logging.py b/tensorflow/python/platform/tf_logging.py index 85ed4f071c7022801f20db75d538e5917b8eea66..22aabfd7121ac9b2eebeae2693f174e044d504ef 100644 --- a/tensorflow/python/platform/tf_logging.py +++ b/tensorflow/python/platform/tf_logging.py @@ -35,6 +35,7 @@ import threading import six from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # Don't use this directly. Use _get_logger() instead. @@ -90,30 +91,37 @@ def _get_logger(): _logger_lock.release() +@tf_export('logging.log') def log(level, msg, *args, **kwargs): _get_logger().log(level, msg, *args, **kwargs) +@tf_export('logging.debug') def debug(msg, *args, **kwargs): _get_logger().debug(msg, *args, **kwargs) +@tf_export('logging.error') def error(msg, *args, **kwargs): _get_logger().error(msg, *args, **kwargs) +@tf_export('logging.fatal') def fatal(msg, *args, **kwargs): _get_logger().fatal(msg, *args, **kwargs) +@tf_export('logging.info') def info(msg, *args, **kwargs): _get_logger().info(msg, *args, **kwargs) +@tf_export('logging.warn') def warn(msg, *args, **kwargs): _get_logger().warn(msg, *args, **kwargs) +@tf_export('logging.warning') def warning(msg, *args, **kwargs): _get_logger().warning(msg, *args, **kwargs) @@ -136,15 +144,18 @@ _log_prefix = None # later set to google2_log_prefix _log_counter_per_token = {} +@tf_export('logging.TaskLevelStatusMessage') def TaskLevelStatusMessage(msg): error(msg) +@tf_export('logging.flush') def flush(): raise NotImplementedError() # Code below is taken from pyglib/logging +@tf_export('logging.vlog') def vlog(level, msg, *args, **kwargs): _get_logger().log(level, msg, *args, **kwargs) @@ -164,6 +175,7 @@ def _GetNextLogCountPerToken(token): return _log_counter_per_token[token] +@tf_export('logging.log_every_n') def log_every_n(level, msg, n, *args): """Log 'msg % args' at level 'level' once per 'n' times. @@ -180,6 +192,7 @@ def log_every_n(level, msg, n, *args): log_if(level, msg, not (count % n), *args) +@tf_export('logging.log_first_n') def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name """Log 'msg % args' at level 'level' only first 'n' times. @@ -195,6 +208,7 @@ def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name log_if(level, msg, count < n, *args) +@tf_export('logging.log_if') def log_if(level, msg, condition, *args): """Log 'msg % args' at level 'level' only if condition is fulfilled.""" if condition: @@ -251,11 +265,13 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None): return s +@tf_export('logging.get_verbosity') def get_verbosity(): """Return how much logging output will be produced.""" return _get_logger().getEffectiveLevel() +@tf_export('logging.set_verbosity') def set_verbosity(v): """Sets the threshold for what messages will be logged.""" _get_logger().setLevel(v) @@ -296,4 +312,10 @@ _allowed_symbols = [ 'warning', ] +tf_export('logging.DEBUG').export_constant(__name__, 'DEBUG') +tf_export('logging.ERROR').export_constant(__name__, 'ERROR') +tf_export('logging.FATAL').export_constant(__name__, 'FATAL') +tf_export('logging.INFO').export_constant(__name__, 'INFO') +tf_export('logging.WARN').export_constant(__name__, 'WARN') + remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py index 8f780545607f7ba2337c83ad2c3740f542b802f6..0e20ca35bba606079ed5b0f225dd3029772b5af3 100644 --- a/tensorflow/python/profiler/model_analyzer.py +++ b/tensorflow/python/profiler/model_analyzer.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.profiler import option_builder from tensorflow.python.profiler import tfprof_logger +from tensorflow.python.util.tf_export import tf_export _DEFAULT_PROFILE_OPTIONS = 0 _DEFAULT_ADVISE_OPTIONS = 0 @@ -121,6 +122,7 @@ def _build_advisor_options(options): return opts +@tf_export('profiler.Profiler') class Profiler(object): """TensorFlow multi-step profiler. @@ -304,6 +306,7 @@ class Profiler(object): print_mdl.WriteProfile(filename) +@tf_export('profiler.profile') def profile(graph=None, run_meta=None, op_log=None, @@ -378,6 +381,7 @@ def profile(graph=None, return tfprof_node +@tf_export('profiler.advise') def advise(graph=None, run_meta=None, options=_DEFAULT_ADVISE_OPTIONS): """Auto profile and advise. diff --git a/tensorflow/python/profiler/option_builder.py b/tensorflow/python/profiler/option_builder.py index 13942ad6a2adc1f1d1cad778ebd280d358f64a59..957ebe6dddc26118024f71cadef861f38f1805e0 100644 --- a/tensorflow/python/profiler/option_builder.py +++ b/tensorflow/python/profiler/option_builder.py @@ -20,8 +20,10 @@ from __future__ import print_function import copy from tensorflow.python.profiler import tfprof_logger +from tensorflow.python.util.tf_export import tf_export +@tf_export('profiler.ProfileOptionBuilder') class ProfileOptionBuilder(object): # pylint: disable=line-too-long """Option Builder for Profiling API. diff --git a/tensorflow/python/profiler/tfprof_logger.py b/tensorflow/python/profiler/tfprof_logger.py index ffda7ddad759ce68bf718bcfa6e568cfadd59b53..8d121064967f2f87cd0aefaa361bfd6f387a3e6e 100644 --- a/tensorflow/python/profiler/tfprof_logger.py +++ b/tensorflow/python/profiler/tfprof_logger.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import gfile from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import +from tensorflow.python.util.tf_export import tf_export TRAINABLE_VARIABLES = '_trainable_variables' REGISTERED_FLOP_STATS = 'flops' @@ -187,6 +188,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None, return tmp_op_log +@tf_export('profiler.write_op_log') def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True): """Log provided 'op_log', and add additional model information below. diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 5ff954fd9f83989565e007cad3f0f66913e0a4dd..6e85df0cbf5623691d38b030036958e5955399ee 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -232,13 +232,9 @@ def load(sess, tags, export_dir, **saver_kwargs): asset_tensors_dictionary = _get_asset_tensors(export_dir, meta_graph_def_to_load) - main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load) + main_op_tensor = (_get_main_op_tensor(meta_graph_def_to_load) or + (_get_legacy_init_op_tensor(meta_graph_def_to_load))) if main_op_tensor is not None: sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) - else: - legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load) - if legacy_init_op_tensor is not None: - sess.run( - fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary) return meta_graph_def_to_load diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 1ea619ff55dea00f8ee09024ab45dcd324a2ddce..f92247d52e4150b2347a95e84d4bbf9c6ffc258a 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -54,8 +54,14 @@ def tearDownModule(): file_io.delete_recursively(test.get_temp_dir()) +@test_util.with_c_api class SavedModelTest(test.TestCase): + def _get_export_dir(self, label): + if ops._USE_C_API: + label += "_c_api" + return os.path.join(test.get_temp_dir(), label) + def _init_and_validate_variable(self, sess, variable_name, variable_value): v = variables.Variable(variable_value, name=variable_name) sess.run(variables.global_variables_initializer()) @@ -123,8 +129,7 @@ class SavedModelTest(test.TestCase): self.assertFalse(loader.maybe_saved_model_directory(base_path)) def testBadSavedModelFileFormat(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_bad_saved_model_file_format") + export_dir = self._get_export_dir("test_bad_saved_model_file_format") # Attempt to load a SavedModel from an export directory that does not exist. with self.test_session(graph=ops.Graph()) as sess: with self.assertRaisesRegexp(IOError, @@ -157,8 +162,7 @@ class SavedModelTest(test.TestCase): loader.load(sess, ["foo"], export_dir) def testVerifySessionGraphUsage(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_verify_session_graph_usage") + export_dir = self._get_export_dir("test_verify_session_graph_usage") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -178,7 +182,7 @@ class SavedModelTest(test.TestCase): 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) def testSequence(self): - export_dir = os.path.join(test.get_temp_dir(), "test_sequence") + export_dir = self._get_export_dir("test_sequence") builder = saved_model_builder.SavedModelBuilder(export_dir) # Expect an assertion error since add_meta_graph_and_variables() should be @@ -195,7 +199,7 @@ class SavedModelTest(test.TestCase): sess, ["baz"]) def testTags(self): - export_dir = os.path.join(test.get_temp_dir(), "test_tags") + export_dir = self._get_export_dir("test_tags") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable. SavedModel invoked to: @@ -284,7 +288,7 @@ class SavedModelTest(test.TestCase): export_dir) def testVariables(self): - export_dir = os.path.join(test.get_temp_dir(), "test_variables") + export_dir = self._get_export_dir("test_variables") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with two variables. SavedModel invoked to: @@ -336,7 +340,7 @@ class SavedModelTest(test.TestCase): export_dir) def testGraphWithoutVariables(self): - export_dir = os.path.join(test.get_temp_dir(), "test_graph_has_variables") + export_dir = self._get_export_dir("test_graph_has_variables") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with no variables. @@ -371,7 +375,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(30.0, sess.run(c)) def testNoOverwrite(self): - export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite") + export_dir = self._get_export_dir("test_no_overwrite") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable. SavedModel invoked to: @@ -395,7 +399,7 @@ class SavedModelTest(test.TestCase): export_dir) def testSaveAsText(self): - export_dir = os.path.join(test.get_temp_dir(), "test_astext") + export_dir = self._get_export_dir("test_astext") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable. SavedModel invoked to: @@ -426,7 +430,7 @@ class SavedModelTest(test.TestCase): 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) def testCollections(self): - export_dir = os.path.join(test.get_temp_dir(), "test_collections") + export_dir = self._get_export_dir("test_collections") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable added to a collection. SavedModel invoked to: @@ -476,7 +480,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(len(ops.get_collection("foo_vars")), 0) def testSignatureDefs(self): - export_dir = os.path.join(test.get_temp_dir(), "test_signature_defs") + export_dir = self._get_export_dir("test_signature_defs") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable and a single entry in the signature def map. @@ -536,8 +540,7 @@ class SavedModelTest(test.TestCase): self.assertEqual("foo_new", bar_signature["foo_key"].method_name) def testSignatureDefValidation(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_signature_def_validation") + export_dir = self._get_export_dir("test_signature_def_validation") builder = saved_model_builder.SavedModelBuilder(export_dir) tensor_without_name = meta_graph_pb2.TensorInfo() @@ -555,7 +558,7 @@ class SavedModelTest(test.TestCase): self._validate_outputs_tensor_info(builder, tensor_empty) def testAssets(self): - export_dir = os.path.join(test.get_temp_dir(), "test_assets") + export_dir = self._get_export_dir("test_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -588,7 +591,7 @@ class SavedModelTest(test.TestCase): self.assertFalse(file_io.file_exists(ignored_asset_path)) def testCustomMainOp(self): - export_dir = os.path.join(test.get_temp_dir(), "test_main_op") + export_dir = self._get_export_dir("test_main_op") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -623,7 +626,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(3, ops.get_collection("v")[2].eval()) def testLegacyInitOp(self): - export_dir = os.path.join(test.get_temp_dir(), "test_legacy_init_op") + export_dir = self._get_export_dir("test_legacy_init_op") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -657,8 +660,8 @@ class SavedModelTest(test.TestCase): self.assertEqual(3, ops.get_collection("v")[2].eval()) def testLegacyInitOpWithNonEmptyCollection(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_legacy_init_op_with_non_empty_collection") + export_dir = self._get_export_dir( + "test_legacy_init_op_with_non_empty_collection") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -685,7 +688,7 @@ class SavedModelTest(test.TestCase): sess, ["foo"], legacy_init_op=legacy_init_op) def testMultipleAssets(self): - export_dir = os.path.join(test.get_temp_dir(), "test_multiple_assets") + export_dir = self._get_export_dir("test_multiple_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -727,7 +730,7 @@ class SavedModelTest(test.TestCase): "asset_file_tensor:0") def testDuplicateAssets(self): - export_dir = os.path.join(test.get_temp_dir(), "test_duplicate_assets") + export_dir = self._get_export_dir("test_duplicate_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -775,7 +778,7 @@ class SavedModelTest(test.TestCase): "asset_file_tensor:0") def testOp(self): - export_dir = os.path.join(test.get_temp_dir(), "test_op") + export_dir = self._get_export_dir("test_op") builder = saved_model_builder.SavedModelBuilder(export_dir) with session.Session( @@ -818,7 +821,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(3, ops.get_collection("v")[2].eval()) def testCustomSaveable(self): - export_dir = os.path.join(test.get_temp_dir(), "custom_saveable") + export_dir = self._get_export_dir("custom_saveable") builder = saved_model_builder.SavedModelBuilder(export_dir) with session.Session( @@ -847,7 +850,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(3.0, v1.values().eval()) def testClearDevices(self): - export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices") + export_dir = self._get_export_dir("test_clear_devices") builder = saved_model_builder.SavedModelBuilder(export_dir) # Specify a device and save a variable. @@ -871,7 +874,9 @@ class SavedModelTest(test.TestCase): 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) def testStripDefaultAttrs(self): - export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs") + if ops._USE_C_API: return # TODO(skyewm): get this working + + export_dir = self._get_export_dir("test_strip_default_attrs") builder = saved_model_builder.SavedModelBuilder(export_dir) # Add a graph with two float32 variables and a Complex Op composing them @@ -941,8 +946,10 @@ class SavedModelTest(test.TestCase): self.assertIn("Tout", node_def.attr) def testStripDefaultAttrsInconsistentConsumerDefaults(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_strip_default_attrs_no_consumer_defaults") + if ops._USE_C_API: return # TODO(skyewm): get this working + + export_dir = self._get_export_dir( + "test_strip_default_attrs_no_consumer_defaults") builder = saved_model_builder.SavedModelBuilder(export_dir) # Add a graph with two float32 variables and a Complex Op composing them diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 12f120116f4439059f42c7212469ee835cc13ef4..1f3f2287043c021d636113b5a8807c9f4adf77aa 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -32,6 +32,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import plugin_asset from tensorflow.python.summary.writer.event_file_writer import EventFileWriter +from tensorflow.python.util.tf_export import tf_export _PLUGINS_DIR = "plugins" @@ -276,6 +277,7 @@ class SummaryToEventTransformer(object): self.event_writer.add_event(event) +@tf_export("summary.FileWriter") class FileWriter(SummaryToEventTransformer): """Writes `Summary` protocol buffers to event files. diff --git a/tensorflow/python/summary/writer/writer_cache.py b/tensorflow/python/summary/writer/writer_cache.py index bad289303c0fd0de7836b03a6762d04505521a89..645fa28a37fb125b6b1224961251bc8879d5fe6d 100644 --- a/tensorflow/python/summary/writer/writer_cache.py +++ b/tensorflow/python/summary/writer/writer_cache.py @@ -22,8 +22,10 @@ import threading from tensorflow.python.framework import ops from tensorflow.python.summary.writer.writer import FileWriter +from tensorflow.python.util.tf_export import tf_export +@tf_export('summary.FileWriterCache') class FileWriterCache(object): """Cache for file writers. diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py index c2687bf557b03ff588fd369771077c92ba012a15..9c1927122252f45ddfa8092045c7589fa0f45532 100644 --- a/tensorflow/python/tools/optimize_for_inference_lib.py +++ b/tensorflow/python/tools/optimize_for_inference_lib.py @@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def): bias_add_op.op = "BiasAdd" bias_add_op.name = node.name bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"]) + bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"]) bias_add_op.input.extend([new_conv_op.name, offset_op.name]) new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op]) diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py index 6dd24c0dca1d326592e4f33eba4e6233248dac5f..2ef612473b4bb64f611983f87f8cdd619a2d8a38 100644 --- a/tensorflow/python/tools/optimize_for_inference_test.py +++ b/tensorflow/python/tools/optimize_for_inference_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import image_ops @@ -38,6 +39,7 @@ from tensorflow.python.platform import test from tensorflow.python.tools import optimize_for_inference_lib +@test_util.with_c_api class OptimizeForInferenceTest(test.TestCase): def create_node_def(self, op, name, inputs): @@ -145,7 +147,7 @@ class OptimizeForInferenceTest(test.TestCase): np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) gamma_op = constant_op.constant( np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) - ops.get_default_graph().graph_def_versions.producer = 8 + test_util.set_producer_version(ops.get_default_graph(), 8) gen_nn_ops._batch_norm_with_global_normalization( conv_op, mean_op, @@ -171,48 +173,53 @@ class OptimizeForInferenceTest(test.TestCase): self.assertNotEqual("BatchNormWithGlobalNormalization", node.op) def testFoldFusedBatchNorms(self): - with self.test_session() as sess: - inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] - input_op = constant_op.constant( - np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32) - weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] - weights_op = constant_op.constant( - np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32) - conv_op = nn_ops.conv2d( - input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op") - mean_op = constant_op.constant( - np.array([10, 20]), shape=[2], dtype=dtypes.float32) - variance_op = constant_op.constant( - np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32) - beta_op = constant_op.constant( - np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) - gamma_op = constant_op.constant( - np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) - ops.get_default_graph().graph_def_versions.producer = 9 - gen_nn_ops._fused_batch_norm( - conv_op, - gamma_op, - beta_op, - mean_op, - variance_op, - 0.00001, - is_training=False, - name="output") - original_graph_def = sess.graph_def - original_result = sess.run(["output:0"]) - optimized_graph_def = optimize_for_inference_lib.fold_batch_norms( - original_graph_def) - - with self.test_session() as sess: - _ = importer.import_graph_def( - optimized_graph_def, input_map={}, name="optimized") - optimized_result = sess.run(["optimized/output:0"]) - - self.assertAllClose( - original_result, optimized_result, rtol=1e-04, atol=1e-06) - - for node in optimized_graph_def.node: - self.assertNotEqual("FusedBatchNorm", node.op) + for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]: + with self.test_session(use_gpu=use_gpu) as sess: + inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] + input_op = constant_op.constant( + np.array(inputs), + shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6], + dtype=dtypes.float32) + weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] + weights_op = constant_op.constant( + np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32) + conv_op = nn_ops.conv2d( + input_op, weights_op, [1, 1, 1, 1], padding="SAME", + data_format=data_format, name="conv_op") + mean_op = constant_op.constant( + np.array([10, 20]), shape=[2], dtype=dtypes.float32) + variance_op = constant_op.constant( + np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32) + beta_op = constant_op.constant( + np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) + gamma_op = constant_op.constant( + np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) + ops.get_default_graph().graph_def_versions.producer = 9 + gen_nn_ops._fused_batch_norm( + conv_op, + gamma_op, + beta_op, + mean_op, + variance_op, + 0.00001, + is_training=False, + data_format=data_format, + name="output") + original_graph_def = sess.graph_def + original_result = sess.run(["output:0"]) + optimized_graph_def = optimize_for_inference_lib.fold_batch_norms( + original_graph_def) + + with self.test_session(use_gpu=use_gpu) as sess: + _ = importer.import_graph_def( + optimized_graph_def, input_map={}, name="optimized") + optimized_result = sess.run(["optimized/output:0"]) + + self.assertAllClose( + original_result, optimized_result, rtol=1e-04, atol=1e-06) + + for node in optimized_graph_def.node: + self.assertNotEqual("FusedBatchNorm", node.op) def testFuseResizePadAndConv(self): with self.test_session() as sess: diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py index 13c07cfd7bf4333fee3edc3c3ad9d2fb7bcbaad2..c08e3cca007dc17f1112d53bf729c1accf61b5df 100644 --- a/tensorflow/python/training/adadelta.py +++ b/tensorflow/python/training/adadelta.py @@ -22,8 +22,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdadeltaOptimizer") class AdadeltaOptimizer(optimizer.Optimizer): """Optimizer that implements the Adadelta algorithm. diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py index afa192f7cc6e0ecd629fd94252d26961f1407183..deb4e6f546379eff330235dbc302a30c44193830 100644 --- a/tensorflow/python/training/adagrad.py +++ b/tensorflow/python/training/adagrad.py @@ -25,8 +25,10 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdagradOptimizer") class AdagradOptimizer(optimizer.Optimizer): """Optimizer that implements the Adagrad algorithm. diff --git a/tensorflow/python/training/adagrad_da.py b/tensorflow/python/training/adagrad_da.py index b3f9ea323c2bb4fd9ecee93863fbc7955b47a947..5ba403554f570d9df33a5d525a40de2eb0d11138 100644 --- a/tensorflow/python/training/adagrad_da.py +++ b/tensorflow/python/training/adagrad_da.py @@ -23,8 +23,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdagradDAOptimizer") class AdagradDAOptimizer(optimizer.Optimizer): """Adagrad Dual Averaging algorithm for sparse linear models. diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index 0c69f8bf3997452f0eeb71c93f4fcf98eb27d8f9..c92f6fc3015960a2b821651231bb94713e0d53dd 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -26,8 +26,10 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdamOptimizer") class AdamOptimizer(optimizer.Optimizer): """Optimizer that implements the Adam algorithm. diff --git a/tensorflow/python/training/basic_loops.py b/tensorflow/python/training/basic_loops.py index 52b0f4210612bad4a2e838153ac9cbdb1023bf66..7af821c81928e67e0f258bc064d582a4186995c1 100644 --- a/tensorflow/python/training/basic_loops.py +++ b/tensorflow/python/training/basic_loops.py @@ -18,8 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import errors +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.basic_train_loop") def basic_train_loop(supervisor, train_step_fn, args=None, kwargs=None, master=""): """Basic loop to train a model. diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 864c2e4406ebba0f35c3a1e9d783af43046325ce..aae757b99aa9abb2fca112dcc781fc31e367649d 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -47,6 +47,7 @@ from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs from tensorflow.python.training.summary_io import SummaryWriterCache +from tensorflow.python.util.tf_export import tf_export class _HookTimer(object): @@ -85,6 +86,7 @@ class _HookTimer(object): raise NotImplementedError +@tf_export("train.SecondOrStepTimer") class SecondOrStepTimer(_HookTimer): """Timer that triggers at most once every N seconds or once every N steps. """ @@ -164,6 +166,7 @@ class NeverTriggerTimer(_HookTimer): return None +@tf_export("train.LoggingTensorHook") class LoggingTensorHook(session_run_hook.SessionRunHook): """Prints the given tensors every N local steps, every N seconds, or at end. @@ -262,6 +265,7 @@ class LoggingTensorHook(session_run_hook.SessionRunHook): self._log_tensors(values) +@tf_export("train.StopAtStepHook") class StopAtStepHook(session_run_hook.SessionRunHook): """Hook that requests stop at a specified step.""" @@ -317,6 +321,7 @@ class StopAtStepHook(session_run_hook.SessionRunHook): run_context.request_stop() +@tf_export("train.CheckpointSaverListener") class CheckpointSaverListener(object): """Interface for listeners that take action before or after checkpoint save. @@ -375,6 +380,7 @@ class CheckpointSaverListener(object): pass +@tf_export("train.CheckpointSaverHook") class CheckpointSaverHook(session_run_hook.SessionRunHook): """Saves checkpoints every N steps or seconds.""" @@ -497,6 +503,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): return savers[0] +@tf_export("train.StepCounterHook") class StepCounterHook(session_run_hook.SessionRunHook): """Hook that counts steps per second.""" @@ -575,12 +582,14 @@ class StepCounterHook(session_run_hook.SessionRunHook): self._last_global_step = stale_global_step +@tf_export("train.NanLossDuringTrainingError") class NanLossDuringTrainingError(RuntimeError): def __str__(self): return "NaN loss during training." +@tf_export("train.NanTensorHook") class NanTensorHook(session_run_hook.SessionRunHook): """Monitors the loss tensor and stops training if loss is NaN. @@ -612,6 +621,7 @@ class NanTensorHook(session_run_hook.SessionRunHook): run_context.request_stop() +@tf_export("train.SummarySaverHook") class SummarySaverHook(session_run_hook.SessionRunHook): """Saves summaries every N steps.""" @@ -720,6 +730,7 @@ class SummarySaverHook(session_run_hook.SessionRunHook): return summary_op +@tf_export("train.GlobalStepWaiterHook") class GlobalStepWaiterHook(session_run_hook.SessionRunHook): """Delays execution until global step reaches `wait_until_step`. @@ -767,6 +778,7 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook): time.sleep(0.5) +@tf_export("train.FinalOpsHook") class FinalOpsHook(session_run_hook.SessionRunHook): """A hook which evaluates `Tensors` at the end of a session.""" @@ -793,6 +805,7 @@ class FinalOpsHook(session_run_hook.SessionRunHook): feed_dict=self._final_ops_feed_dict) +@tf_export("train.FeedFnHook") class FeedFnHook(session_run_hook.SessionRunHook): """Runs `feed_fn` and sets the `feed_dict` accordingly.""" @@ -810,6 +823,7 @@ class FeedFnHook(session_run_hook.SessionRunHook): fetches=None, feed_dict=self.feed_fn()) +@tf_export("train.ProfilerHook") class ProfilerHook(session_run_hook.SessionRunHook): """Captures CPU/GPU profiling information every N steps or seconds. diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index b5d3e7879711c2a4fca1c8d7e47288c3d12d1b0d..fa3de6fad27b6cc773f9f2e86e9f95395eb7c285 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -36,6 +37,7 @@ __all__ = [ ] +@tf_export("train.load_checkpoint") def load_checkpoint(ckpt_dir_or_file): """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`. @@ -60,6 +62,7 @@ def load_checkpoint(ckpt_dir_or_file): return pywrap_tensorflow.NewCheckpointReader(filename) +@tf_export("train.load_variable") def load_variable(ckpt_dir_or_file, name): """Returns the tensor value of the given variable in the checkpoint. @@ -77,6 +80,7 @@ def load_variable(ckpt_dir_or_file, name): return reader.get_tensor(name) +@tf_export("train.list_variables") def list_variables(ckpt_dir_or_file): """Returns list of all variables in the checkpoint. @@ -95,6 +99,7 @@ def list_variables(ckpt_dir_or_file): return result +@tf_export("train.init_from_checkpoint") def init_from_checkpoint(ckpt_dir_or_file, assignment_map): """Initializes current variables with tensors loaded from given checkpoint. @@ -242,6 +247,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): full_tensor_name = full_tensor_name[1:] if tensor_name_in_ckpt != "/": full_tensor_name = tensor_name_in_ckpt + full_tensor_name + # Remove trailing '/', if any, in the full_tensor_name + if full_tensor_name.endswith("/"): + full_tensor_name = full_tensor_name[:-1] if full_tensor_name not in variable_map: raise ValueError( "Tensor %s (%s in %s) is not found in %s checkpoint" % ( diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py index 0e31255b74f64657cffc4a2f58798835513f0444..0ff97d85e37e6167f1200ba56940f4a663c259a2 100644 --- a/tensorflow/python/training/coordinator.py +++ b/tensorflow/python/training/coordinator.py @@ -27,8 +27,10 @@ import six from tensorflow.python.framework import errors from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.Coordinator") class Coordinator(object): """A coordinator for threads. @@ -406,6 +408,7 @@ class Coordinator(object): # Threads for the standard services. +@tf_export("train.LooperThread") class LooperThread(threading.Thread): """A thread that runs code repeatedly, optionally on a timer. diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index 37ab625779f788b1b8e270a15db3244ea6f1bef3..689088bb41edfd94a1d483ed2b5f7447e9e060e7 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -23,6 +23,7 @@ from tensorflow.core.framework import node_def_pb2 from tensorflow.python.framework import device as pydev from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +from tensorflow.python.util.tf_export import tf_export class _RoundRobinStrategy(object): @@ -121,6 +122,7 @@ class _ReplicaDeviceChooser(object): return worker_device.to_string() +@tf_export("train.replica_device_setter") def replica_device_setter(ps_tasks=0, ps_device="/job:ps", worker_device="/job:worker", merge_devices=True, cluster=None, ps_ops=None, ps_strategy=None): diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py index c64a1b3f799e776c7bbbbcfb691bdd97e4a34466..9d02e694db15637126f37ee5575638908b351def 100644 --- a/tensorflow/python/training/ftrl.py +++ b/tensorflow/python/training/ftrl.py @@ -22,8 +22,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.FtrlOptimizer") class FtrlOptimizer(optimizer.Optimizer): """Optimizer that implements the FTRL algorithm. @@ -265,4 +267,3 @@ class FtrlOptimizer(optimizer.Optimizer): grad.dtype), math_ops.cast(self._learning_rate_power_tensor, grad.dtype), use_locking=self._use_locking) - diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py index 5a536e27297f054671e7e44a9e5d20a8b36580b7..380e14e02497fbe3681d6bae03fe9c636c5d13aa 100644 --- a/tensorflow/python/training/gradient_descent.py +++ b/tensorflow/python/training/gradient_descent.py @@ -23,8 +23,10 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.GradientDescentOptimizer") class GradientDescentOptimizer(optimizer.Optimizer): """Optimizer that implements the gradient descent algorithm. """ diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 331a51e8bc848917967fed06632fe0d1c5bcad9c..bd9985a7c5c181c0431e0c0a91186bc36b11c787 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.summary import summary from tensorflow.python.training import queue_runner +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access @@ -53,9 +54,12 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map # pylint: enable=protected-access +@tf_export("train.match_filenames_once") def match_filenames_once(pattern, name=None): """Save the list of files matching pattern, so it is only computed once. + NOTE: The order of the files returned can be non-deterministic. + Args: pattern: A file pattern (glob), or 1D tensor of file patterns. name: A name for the operations (optional). @@ -70,6 +74,7 @@ def match_filenames_once(pattern, name=None): collections=[ops.GraphKeys.LOCAL_VARIABLES]) +@tf_export("train.limit_epochs") def limit_epochs(tensor, num_epochs=None, name=None): """Returns tensor `num_epochs` times and then raises an `OutOfRange` error. @@ -102,6 +107,7 @@ def limit_epochs(tensor, num_epochs=None, name=None): return array_ops.identity(tensor, name=name) +@tf_export("train.input_producer") def input_producer(input_tensor, element_shape=None, num_epochs=None, @@ -184,6 +190,7 @@ def input_producer(input_tensor, return q +@tf_export("train.string_input_producer") def string_input_producer(string_tensor, num_epochs=None, shuffle=True, @@ -253,6 +260,7 @@ def string_input_producer(string_tensor, cancel_op=cancel_op) +@tf_export("train.range_input_producer") def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None): """Produces the integers from 0 to limit-1 in a queue. @@ -290,6 +298,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, shared_name, "fraction_of_%d_full" % capacity, name) +@tf_export("train.slice_input_producer") def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None): """Produces a slice of each `Tensor` in `tensor_list`. @@ -885,6 +894,7 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity, # Batching functions ---------------------------------------------------------- +@tf_export("train.batch") def batch(tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -979,6 +989,7 @@ def batch(tensors, batch_size, num_threads=1, capacity=32, name=name) +@tf_export("train.maybe_batch") def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1031,6 +1042,7 @@ def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32, name=name) +@tf_export("train.batch_join") def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1136,6 +1148,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, name=name) +@tf_export("train.maybe_batch_join") def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, @@ -1188,6 +1201,7 @@ def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32, name=name) +@tf_export("train.shuffle_batch") def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1287,6 +1301,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, name=name) +@tf_export("train.maybe_shuffle_batch") def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, keep_input, num_threads=1, seed=None, enqueue_many=False, shapes=None, @@ -1346,6 +1361,7 @@ def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, name=name) +@tf_export("train.shuffle_batch_join") def shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, @@ -1439,6 +1455,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity, name=name) +@tf_export("train.maybe_shuffle_batch_join") def maybe_shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, keep_input, seed=None, enqueue_many=False, shapes=None, diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index 343a49cded01ec3665170c85767bcc812531df78..10ab4c1137ff226d88902143d4f2281ad77de531 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -25,8 +25,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.exponential_decay") def exponential_decay(learning_rate, global_step, decay_steps, @@ -103,6 +105,7 @@ def exponential_decay(learning_rate, learning_rate, math_ops.pow(decay_rate, p), name=name) +@tf_export("train.piecewise_constant") def piecewise_constant(x, boundaries, values, name=None): """Piecewise constant from boundaries and interval values. @@ -182,6 +185,7 @@ def piecewise_constant(x, boundaries, values, name=None): return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) +@tf_export("train.polynomial_decay") def polynomial_decay(learning_rate, global_step, decay_steps, @@ -291,6 +295,7 @@ def polynomial_decay(learning_rate, name=name) +@tf_export("train.natural_exp_decay") def natural_exp_decay(learning_rate, global_step, decay_steps, @@ -362,6 +367,7 @@ def natural_exp_decay(learning_rate, return math_ops.multiply(learning_rate, exponent, name=name) +@tf_export("train.inverse_time_decay") def inverse_time_decay(learning_rate, global_step, decay_steps, @@ -444,6 +450,7 @@ def inverse_time_decay(learning_rate, return math_ops.div(learning_rate, denom, name=name) +@tf_export("train.cosine_decay") def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): """Applies cosine decay to the learning rate. @@ -503,6 +510,7 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): return math_ops.multiply(learning_rate, decayed) +@tf_export("train.cosine_decay_restarts") def cosine_decay_restarts(learning_rate, global_step, first_decay_steps, @@ -596,6 +604,7 @@ def cosine_decay_restarts(learning_rate, return math_ops.multiply(learning_rate, decayed, name=name) +@tf_export("train.linear_cosine_decay") def linear_cosine_decay(learning_rate, global_step, decay_steps, @@ -679,6 +688,7 @@ def linear_cosine_decay(learning_rate, return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) +@tf_export("train.noisy_linear_cosine_decay") def noisy_linear_cosine_decay(learning_rate, global_step, decay_steps, diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py index cf9530d87c46783b517884610b644b076bef6807..bd9fa79d8feac68c149f787ee8501bdddb173d33 100644 --- a/tensorflow/python/training/momentum.py +++ b/tensorflow/python/training/momentum.py @@ -22,8 +22,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.MomentumOptimizer") class MomentumOptimizer(optimizer.Optimizer): """Optimizer that implements the Momentum algorithm. diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index fa3517db27be4581deb85f77f022406b8b30ec56..6c5c9e01a76d539b550420134b09090b89beed46 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -41,6 +41,7 @@ from tensorflow.python.training import queue_runner from tensorflow.python.training import saver as training_saver from tensorflow.python.training import session_manager as sm from tensorflow.python.training import session_run_hook +from tensorflow.python.util.tf_export import tf_export # The list of exceptions that we should recover from. Exceptions not in this @@ -52,6 +53,7 @@ _PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError) USE_DEFAULT = object() +@tf_export('train.Scaffold') class Scaffold(object): """Structure to create or gather pieces commonly needed to train a model. @@ -272,6 +274,7 @@ class Scaffold(object): resources.initialize_resources(resources.local_resources())) +@tf_export('train.MonitoredTrainingSession') def MonitoredTrainingSession(master='', # pylint: disable=invalid-name is_chief=True, checkpoint_dir=None, @@ -381,6 +384,7 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name stop_grace_period_secs=stop_grace_period_secs) +@tf_export('train.SessionCreator') class SessionCreator(object): """A factory for tf.Session.""" @@ -390,6 +394,7 @@ class SessionCreator(object): 'create_session is not implemented for {}.'.format(self)) +@tf_export('train.ChiefSessionCreator') class ChiefSessionCreator(SessionCreator): """Creates a tf.Session for a chief.""" @@ -441,6 +446,7 @@ class ChiefSessionCreator(SessionCreator): init_fn=self._scaffold.init_fn) +@tf_export('train.WorkerSessionCreator') class WorkerSessionCreator(SessionCreator): """Creates a tf.Session for a worker.""" @@ -706,6 +712,7 @@ class _MonitoredSession(object): return self._coordinated_creator.tf_sess +@tf_export('train.MonitoredSession') class MonitoredSession(_MonitoredSession): """Session-like object that handles initialization, recovery and hooks. @@ -788,6 +795,7 @@ class MonitoredSession(_MonitoredSession): stop_grace_period_secs=stop_grace_period_secs) +@tf_export('train.SingularMonitoredSession') class SingularMonitoredSession(_MonitoredSession): """Session-like object that handles initialization, restoring, and hooks. diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 43ed1ac170d0d993bf7b5bcaff3ff6a8cbbde6b2..2d89082ad75ff8b39575711bdbbc3f454f99a70d 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import slot_creator +from tensorflow.python.util.tf_export import tf_export # TODO(touts): switch to variables.Variable. @@ -230,6 +231,7 @@ def _zero_debias(unbiased_var, value, decay): return unbiased_ema_delta +@tf_export("train.ExponentialMovingAverage") class ExponentialMovingAverage(object): """Maintains moving averages of variables by employing an exponential decay. diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index a06b3eada6cf2f0e4b5690ebe0c92e60f5d2ec0e..9ec588bac96d8c8404dee994bc5991f897abbf77 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import slot_creator from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export def _get_variable_for(v): @@ -187,6 +188,7 @@ def _get_processor(v): raise NotImplementedError("Trying to optimize unsupported type ", v) +@tf_export("train.Optimizer") class Optimizer(object): """Base class for optimizers. @@ -600,7 +602,7 @@ class Optimizer(object): if executing_eagerly: # No variable.op in eager mode. We don't expect lots of eager graphs, # but behavior should be consistent with graph mode. - return variable._container_prefix == current_graph._container_prefix # pylint: disable=protected-access + return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access else: return variable.op.graph is current_graph diff --git a/tensorflow/python/training/proximal_adagrad.py b/tensorflow/python/training/proximal_adagrad.py index da31ab325d5e45e1943f554c45717cceb4dc638f..9bd677b8efcd447f74ec2a3cbe94d63eeb9a4dd1 100644 --- a/tensorflow/python/training/proximal_adagrad.py +++ b/tensorflow/python/training/proximal_adagrad.py @@ -23,8 +23,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.ProximalAdagradOptimizer") class ProximalAdagradOptimizer(optimizer.Optimizer): # pylint: disable=line-too-long """Optimizer that implements the Proximal Adagrad algorithm. diff --git a/tensorflow/python/training/proximal_gradient_descent.py b/tensorflow/python/training/proximal_gradient_descent.py index 53e9dc2ef2c86a20070fdbdc690b39d2c0e9df06..369b6cbb50e5c621737c095a24eeb473f3870534 100644 --- a/tensorflow/python/training/proximal_gradient_descent.py +++ b/tensorflow/python/training/proximal_gradient_descent.py @@ -24,8 +24,10 @@ from tensorflow.python.ops import math_ops # pylint: enable=unused-import from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.ProximalGradientDescentOptimizer") class ProximalGradientDescentOptimizer(optimizer.Optimizer): # pylint: disable=line-too-long """Optimizer that implements the proximal gradient descent algorithm. diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py index 4e7c81d7b2913d71a23dcaa3751db2aaffdc67cf..07afba79abf4d636c9ec2d53bcf2641594a35733 100644 --- a/tensorflow/python/training/queue_runner_impl.py +++ b/tensorflow/python/training/queue_runner_impl.py @@ -27,8 +27,10 @@ from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.queue_runner.QueueRunner", "train.QueueRunner") class QueueRunner(object): """Holds a list of enqueue operations for a queue, each to be run in a thread. @@ -384,6 +386,7 @@ class QueueRunner(object): import_scope=import_scope) +@tf_export("train.queue_runner.add_queue_runner", "train.add_queue_runner") def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS): """Adds a `QueueRunner` to a collection in the graph. @@ -402,6 +405,8 @@ def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS): ops.add_to_collection(collection, qr) +@tf_export("train.queue_runner.start_queue_runners", + "train.start_queue_runners") def start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection=ops.GraphKeys.QUEUE_RUNNERS): """Starts all queue runners collected in the graph. diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py index 745e61201823d81ac01f117be89b910b2810f80c..89d1099a49fedf2cd2ae372cb9c5f7422d43acc2 100644 --- a/tensorflow/python/training/rmsprop.py +++ b/tensorflow/python/training/rmsprop.py @@ -46,8 +46,10 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.RMSPropOptimizer") class RMSPropOptimizer(optimizer.Optimizer): """Optimizer that implements the RMSProp algorithm. diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 4f3773c0fc71e1f1abd8197dea94ce2a63881389..764f8400122118b6abcbad25ce0555954e38d29d 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -53,6 +53,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export # Op names which identify variable reads which should be saved. @@ -889,6 +890,7 @@ def _GetCheckpointFilename(save_dir, latest_filename): return os.path.join(save_dir, latest_filename) +@tf_export("train.generate_checkpoint_state_proto") def generate_checkpoint_state_proto(save_dir, model_checkpoint_path, all_model_checkpoint_paths=None): @@ -933,6 +935,7 @@ def generate_checkpoint_state_proto(save_dir, return coord_checkpoint_proto +@tf_export("train.update_checkpoint_state") def update_checkpoint_state(save_dir, model_checkpoint_path, all_model_checkpoint_paths=None, @@ -1025,6 +1028,7 @@ def _update_checkpoint_state(save_dir, text_format.MessageToString(ckpt)) +@tf_export("train.get_checkpoint_state") def get_checkpoint_state(checkpoint_dir, latest_filename=None): """Returns CheckpointState proto from the "checkpoint" file. @@ -1082,6 +1086,7 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): return ckpt +@tf_export("train.Saver") class Saver(object): """Saves and restores variables. @@ -1229,7 +1234,7 @@ class Saver(object): The `saver_def` proto should be the one returned by the `as_saver_def()` call of the `Saver` that was created for that `Graph`. builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. - Defaults to `BaseSaverBuilder()`. + Defaults to `BulkSaverBuilder()`. defer_build: If `True`, defer adding the save and restore ops to the `build()` call. In that case `build()` should be called before finalizing the graph or using the saver. @@ -1309,7 +1314,7 @@ class Saver(object): if not self.saver_def or context.in_eager_mode(): if self._builder is None: - self._builder = BaseSaverBuilder(self._write_version) + self._builder = BulkSaverBuilder(self._write_version) if self._var_list is None: # pylint: disable=protected-access @@ -1788,6 +1793,7 @@ def _prefix_to_checkpoint_path(prefix, format_version): return prefix # Just the data file. +@tf_export("train.latest_checkpoint") def latest_checkpoint(checkpoint_dir, latest_filename=None): """Finds the filename of latest saved checkpoint file. @@ -1817,6 +1823,7 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None): return None +@tf_export("train.import_meta_graph") def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs): """Recreates a Graph saved in a `MetaGraphDef` proto. @@ -1918,6 +1925,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, return None +@tf_export("train.export_meta_graph") def export_meta_graph(filename=None, meta_info_def=None, graph_def=None, @@ -1994,6 +2002,7 @@ def export_meta_graph(filename=None, return meta_graph_def +@tf_export("train.checkpoint_exists") def checkpoint_exists(checkpoint_prefix): """Checks whether a V1 or V2 checkpoint exists with the specified prefix. @@ -2018,6 +2027,7 @@ def checkpoint_exists(checkpoint_prefix): return False +@tf_export("train.get_checkpoint_mtimes") def get_checkpoint_mtimes(checkpoint_prefixes): """Returns the mtimes (modification timestamps) of the checkpoints. diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index 29da67a30a58c1b8b8e172b2ccede340880fef58..2f421d1cc0a0190670082fabf4e25470c6a1723b 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -23,6 +23,7 @@ from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import errors from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export def _make_server_def(server_or_cluster_def, job_name, task_index, protocol, @@ -92,6 +93,7 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol, return server_def +@tf_export("train.Server") class Server(object): """An in-process TensorFlow server, for use in distributed training. @@ -221,6 +223,7 @@ class Server(object): start=start) +@tf_export("train.ClusterSpec") class ClusterSpec(object): """Represents a cluster as a set of "tasks", organized into "jobs". diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index b396a1e7d0a06ec7b952ba2980e081e01e681d4d..360e02fb44c1062f71bb50449b9ef381510a9c69 100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as saver_mod +from tensorflow.python.util.tf_export import tf_export def _maybe_name(obj): @@ -44,6 +45,7 @@ def _maybe_name(obj): return "" % type(obj) +@tf_export("train.SessionManager") class SessionManager(object): """Training helper that restores from checkpoint and creates session. diff --git a/tensorflow/python/training/session_run_hook.py b/tensorflow/python/training/session_run_hook.py index 5b023d8a2672af5d1fab1c2566b19fca738fd1f7..89f40300650f3b6cd1ae15d946640c9df91771e2 100644 --- a/tensorflow/python/training/session_run_hook.py +++ b/tensorflow/python/training/session_run_hook.py @@ -96,8 +96,10 @@ from __future__ import division from __future__ import print_function import collections +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.SessionRunHook") class SessionRunHook(object): """Hook to extend calls to MonitoredSession.run().""" @@ -189,6 +191,7 @@ class SessionRunHook(object): pass +@tf_export("train.SessionRunArgs") class SessionRunArgs( collections.namedtuple("SessionRunArgs", ["fetches", "feed_dict", "options"])): @@ -213,6 +216,7 @@ class SessionRunArgs( return super(SessionRunArgs, cls).__new__(cls, fetches, feed_dict, options) +@tf_export("train.SessionRunContext") class SessionRunContext(object): """Provides information about the `session.run()` call being made. @@ -264,6 +268,7 @@ class SessionRunContext(object): self._stop_requested = True +@tf_export("train.SessionRunValues") class SessionRunValues( collections.namedtuple("SessionRunValues", ["results", "options", "run_metadata"])): diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index e4514aaea223b6b254a7a72e11e6b70b576fd54b..d2ad34773e0615256c340826dcc312cc8a00dc23 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -37,8 +37,10 @@ from tensorflow.python.training import saver as saver_mod from tensorflow.python.training import session_manager as session_manager_mod from tensorflow.python.training import training_util from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.Supervisor") class Supervisor(object): """A training helper that checkpoints models and computes summaries. diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py index 47702fdad05d13015e0cbf7768129b0c53b6c14c..0c6cf910d1a01dc20b15fb1cd5dbb249fbb60ef5 100644 --- a/tensorflow/python/training/sync_replicas_optimizer.py +++ b/tensorflow/python/training/sync_replicas_optimizer.py @@ -31,6 +31,7 @@ from tensorflow.python.training import optimizer from tensorflow.python.training import queue_runner from tensorflow.python.training import session_manager from tensorflow.python.training import session_run_hook +from tensorflow.python.util.tf_export import tf_export # Please note that the gradients from replicas are averaged instead of summed @@ -38,6 +39,7 @@ from tensorflow.python.training import session_run_hook # rate according to the number of replicas. This change is introduced to be # consistent with how gradients are aggregated (averaged) within a batch in a # replica. +@tf_export("train.SyncReplicasOptimizer") class SyncReplicasOptimizer(optimizer.Optimizer): """Class to synchronize, aggregate gradients and pass them to the optimizer. diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 89a9e129328fe38da2ce497a7f26dc11446ea032..499f1feb2dbf8aee26314a43b0a000fb91a1c686 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export # Picked a long key value to minimize the chance of collision with user defined @@ -40,6 +41,7 @@ GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache' write_graph = graph_io.write_graph +@tf_export('train.global_step') def global_step(sess, global_step_tensor): """Small helper to get the global step. @@ -67,6 +69,7 @@ def global_step(sess, global_step_tensor): return int(sess.run(global_step_tensor)) +@tf_export('train.get_global_step') def get_global_step(graph=None): """Get the global step tensor. @@ -101,6 +104,7 @@ def get_global_step(graph=None): return global_step_tensor +@tf_export('train.create_global_step') def create_global_step(graph=None): """Create global step tensor in graph. @@ -139,6 +143,7 @@ def create_global_step(graph=None): ops.GraphKeys.GLOBAL_STEP]) +@tf_export('train.get_or_create_global_step') def get_or_create_global_step(graph=None): """Returns and create (if necessary) the global step tensor. @@ -156,6 +161,7 @@ def get_or_create_global_step(graph=None): return global_step_tensor +@tf_export('train.assert_global_step') def assert_global_step(global_step_tensor): """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`. diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 270d96a3c7c831d8c06dd86199cf2dc5dfc43421..7e5f192b8f1ae5c86e463c7560553f2bcfd15995 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -41,8 +41,10 @@ import numpy as _np import six as _six from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export +@tf_export('compat.as_bytes', 'compat.as_str') def as_bytes(bytes_or_text, encoding='utf-8'): """Converts either bytes or unicode to `bytes`, using utf-8 encoding for text. @@ -65,6 +67,7 @@ def as_bytes(bytes_or_text, encoding='utf-8'): (bytes_or_text,)) +@tf_export('compat.as_text') def as_text(bytes_or_text, encoding='utf-8'): """Returns the given argument as a unicode string. @@ -93,6 +96,7 @@ else: as_str = as_text +@tf_export('compat.as_str_any') def as_str_any(value): """Converts to `str` as `str(value)`, but use `as_str` for `bytes`. @@ -125,11 +129,16 @@ def path_to_str(path): # Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we # need to check them specifically. The same goes from Real and Complex. integral_types = (_numbers.Integral, _np.integer) +tf_export('compat.integral_types').export_constant(__name__, 'integral_types') real_types = (_numbers.Real, _np.integer, _np.floating) +tf_export('compat.real_types').export_constant(__name__, 'real_types') complex_types = (_numbers.Complex, _np.number) +tf_export('compat.complex_types').export_constant(__name__, 'complex_types') # Either bytes or text. bytes_or_text_types = (bytes, _six.text_type) +tf_export('compat.bytes_or_text_types').export_constant(__name__, + 'bytes_or_text_types') _allowed_symbols = [ 'as_str', diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 874df3d1087e157f8bfcec12ba3495e341c14b7b..c8525ed42039e151f2b44c472690daf1b0727be7 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -532,8 +532,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): (list(_six.iterkeys(input_tree)), list(_six.iterkeys(shallow_tree)))) - input_tree = list(_six.iteritems(input_tree)) - shallow_tree = list(_six.iteritems(shallow_tree)) + input_tree = list(sorted(_six.iteritems(input_tree))) + shallow_tree = list(sorted(_six.iteritems(shallow_tree))) for shallow_branch, input_branch in zip(shallow_tree, input_tree): assert_shallow_structure(shallow_branch, input_branch, diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 6bec397db577c5be5847a701ccc92367dc008fc9..8aaf799fd05420d898a53a11d65e09f3a545e69d 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -425,6 +425,10 @@ class NestTest(test.TestCase): with self.assertRaisesRegexp(ValueError, expected_message): nest.assert_shallow_structure(inp_ab2, inp_ab1) + inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) + inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) + nest.assert_shallow_structure(inp_ab, inp_ba) + def testFlattenUpTo(self): # Shallow tree ends at scalar. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index d11031639592aa1d3e6ce1c7f09c2f0679b29854..66bbd572a673e3ef2da9abc75348e4f70e0cea47 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -77,6 +77,16 @@ genrule( "api/nn/rnn_cell/__init__.py", "api/sets/__init__.py", "api/summary/__init__.py", + "api/train/queue_runner/__init__.py", + "api/compat/__init__.py", + "api/data/__init__.py", + "api/estimator/__init__.py", + "api/estimator/export/__init__.py", + "api/estimator/inputs/__init__.py", + "api/feature_column/__init__.py", + "api/losses/__init__.py", + "api/profiler/__init__.py", + "api/python_io/__init__.py", ], cmd = "$(location create_python_api) $(OUTS)", tools = ["create_python_api"], diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt index ab697b1b95b15e3ac7974e7092f1d5934b088bb6..874a73f661d782ff5637b751f104fd2209734599 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt index b73f6433e226f6b570b68c6a419c53d5c808d9d6..8da2a2b6867a3f9a3d82fcdb76ac4a62d5cee825 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt index dbcc187f94509e3c9265d59cb76d0cdd01bd2333..aa6ac46613fbead7457b19e1aae5f2532afddef1 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "mode" mtype: "" } + member { + name: "prediction_hooks" + mtype: "" + } member { name: "predictions" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index e8890e9cc0a3c659b3f5f377136a2ca616d55993..066c4513ff5185b50bdf193f579e71e505dbd3b6 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -2056,6 +2056,18 @@ tf_module { name: "unsorted_segment_max" argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "unsorted_segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sqrt_n" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "unsorted_segment_sum" argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8fb6b1cdfd8981e427062e186f6ac26b24231b8b..608a34ab7b32bdc26cebbe43b383155406fb51b2 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -17,10 +17,6 @@ py_test( name = "api_compatibility_test", srcs = ["api_compatibility_test.py"], data = [ - ":convert_from_multiline", - "//tensorflow/core/api_def:base_api_def", - "//tensorflow/core/api_def:python_api_def", - "//tensorflow/python:hidden_ops", "//tensorflow/tools/api/golden:api_golden", "//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt", "//tensorflow/tools/api/tests:README.txt", @@ -29,7 +25,6 @@ py_test( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/tools/api/lib:python_object_to_proto_visitor", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index afcbf50944cc47b3ae3086b17279f2ce2fdc6ee7..c1e09cc531ed8e8995e3e73b87e96b72fba6c038 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -28,10 +28,8 @@ from __future__ import division from __future__ import print_function import argparse -from collections import defaultdict import os import re -import subprocess import sys import unittest @@ -39,7 +37,6 @@ import tensorflow as tf from google.protobuf import text_format -from tensorflow.core.framework import api_def_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -67,11 +64,6 @@ _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden' _TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt' _UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt' -_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline' -_BASE_API_DIR = 'tensorflow/core/api_def/base_api' -_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api' -_HIDDEN_OPS_FILE = 'tensorflow/python/ops/hidden_ops.txt' - def _KeyToFilePath(key): """From a given key, construct a filepath.""" @@ -96,55 +88,6 @@ def _FileNameToKey(filename): return api_object_key -def _GetSymbol(symbol_id): - """Get TensorFlow symbol based on the given identifier. - - Args: - symbol_id: Symbol identifier in the form module1.module2. ... .sym. - - Returns: - Symbol corresponding to the given id. - """ - # Ignore first module which should be tensorflow - symbol_id_split = symbol_id.split('.')[1:] - symbol = tf - for sym in symbol_id_split: - symbol = getattr(symbol, sym) - return symbol - - -def _IsGenModule(module_name): - if not module_name: - return False - module_name_split = module_name.split('.') - return module_name_split[-1].startswith('gen_') - - -def _GetHiddenOps(): - hidden_ops_file = file_io.FileIO(_HIDDEN_OPS_FILE, 'r') - hidden_ops = set() - for line in hidden_ops_file: - line = line.strip() - if not line: - continue - if line[0] == '#': # comment line - continue - # If line is of the form "op_name # comment", only keep the op_name. - line_split = line.split('#') - hidden_ops.add(line_split[0].strip()) - return hidden_ops - - -def _GetGoldenApiDefs(): - old_api_def_files = file_io.get_matching_files(_GetApiDefFilePath('*')) - return {file_path: file_io.read_file_to_string(file_path) - for file_path in old_api_def_files} - - -def _GetApiDefFilePath(graph_op_name): - return os.path.join(_PYTHON_API_DIR, 'api_def_%s.pbtxt' % graph_op_name) - - class ApiCompatibilityTest(test.TestCase): def __init__(self, *args, **kwargs): @@ -287,188 +230,6 @@ class ApiCompatibilityTest(test.TestCase): update_goldens=FLAGS.update_goldens) -class ApiDefTest(test.TestCase): - - def __init__(self, *args, **kwargs): - super(ApiDefTest, self).__init__(*args, **kwargs) - self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)') - self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])') - - def _GenerateLowerCaseOpName(self, op_name): - lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name) - return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower() - - def _CreatePythonApiDef(self, base_api_def, endpoint_names): - """Creates Python ApiDef that overrides base_api_def if needed. - - Args: - base_api_def: (api_def_pb2.ApiDef) base ApiDef instance. - endpoint_names: List of Python endpoint names. - - Returns: - api_def_pb2.ApiDef instance with overrides for base_api_def - if module.name endpoint is different from any existing - endpoints in base_api_def. Otherwise, returns None. - """ - endpoint_names_set = set(endpoint_names) - - # If the only endpoint is equal to graph_op_name then - # it is equivalent to having no endpoints. - if (not base_api_def.endpoint and len(endpoint_names) == 1 - and endpoint_names[0] == - self._GenerateLowerCaseOpName(base_api_def.graph_op_name)): - return None - - base_endpoint_names_set = { - self._GenerateLowerCaseOpName(endpoint.name) - for endpoint in base_api_def.endpoint} - - if endpoint_names_set == base_endpoint_names_set: - return None # All endpoints are the same - - api_def = api_def_pb2.ApiDef() - api_def.graph_op_name = base_api_def.graph_op_name - - for endpoint_name in sorted(endpoint_names): - new_endpoint = api_def.endpoint.add() - new_endpoint.name = endpoint_name - - return api_def - - def _GetBaseApiMap(self): - """Get a map from graph op name to its base ApiDef. - - Returns: - Dictionary mapping graph op name to corresponding ApiDef. - """ - # Convert base ApiDef in Multiline format to Proto format. - converted_base_api_dir = os.path.join( - test.get_temp_dir(), 'temp_base_api_defs') - subprocess.check_call( - [os.path.join(resource_loader.get_root_dir_with_all_resources(), - _CONVERT_FROM_MULTILINE_SCRIPT), - _BASE_API_DIR, converted_base_api_dir]) - - name_to_base_api_def = {} - base_api_files = file_io.get_matching_files( - os.path.join(converted_base_api_dir, 'api_def_*.pbtxt')) - for base_api_file in base_api_files: - if file_io.file_exists(base_api_file): - api_defs = api_def_pb2.ApiDefs() - text_format.Merge( - file_io.read_file_to_string(base_api_file), api_defs) - for api_def in api_defs.op: - name_to_base_api_def[api_def.graph_op_name] = api_def - return name_to_base_api_def - - def _AddHiddenOpOverrides(self, name_to_base_api_def, api_def_map): - """Adds ApiDef overrides to api_def_map for hidden Python ops. - - Args: - name_to_base_api_def: Map from op name to base api_def_pb2.ApiDef. - api_def_map: Map from file path to api_def_pb2.ApiDefs for Python API - overrides. - """ - hidden_ops = _GetHiddenOps() - for hidden_op in hidden_ops: - if hidden_op not in name_to_base_api_def: - logging.warning('Unexpected hidden op name: %s' % hidden_op) - continue - - base_api_def = name_to_base_api_def[hidden_op] - if base_api_def.visibility != api_def_pb2.ApiDef.HIDDEN: - api_def = api_def_pb2.ApiDef() - api_def.graph_op_name = base_api_def.graph_op_name - api_def.visibility = api_def_pb2.ApiDef.HIDDEN - - file_path = _GetApiDefFilePath(base_api_def.graph_op_name) - api_def_map[file_path].op.extend([api_def]) - - @unittest.skipUnless( - sys.version_info.major == 2 and os.uname()[0] == 'Linux', - 'API compabitility test goldens are generated using python2 on Linux.') - def testAPIDefCompatibility(self): - # Get base ApiDef - name_to_base_api_def = self._GetBaseApiMap() - snake_to_camel_graph_op_names = { - self._GenerateLowerCaseOpName(name): name - for name in name_to_base_api_def.keys()} - # Extract Python API - visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() - public_api_visitor = public_api.PublicAPIVisitor(visitor) - public_api_visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf, public_api_visitor) - proto_dict = visitor.GetProtos() - - # Map from file path to Python ApiDefs. - new_api_defs_map = defaultdict(api_def_pb2.ApiDefs) - # We need to override all endpoints even if 1 endpoint differs from base - # ApiDef. So, we first create a map from an op to all its endpoints. - op_to_endpoint_name = defaultdict(list) - - # Generate map from generated python op to endpoint names. - for public_module, value in proto_dict.items(): - module_obj = _GetSymbol(public_module) - for sym in value.tf_module.member_method: - obj = getattr(module_obj, sym.name) - - # Check if object is defined in gen_* module. That is, - # the object has been generated from OpDef. - if hasattr(obj, '__module__') and _IsGenModule(obj.__module__): - if obj.__name__ not in snake_to_camel_graph_op_names: - # Symbol might be defined only in Python and not generated from - # C++ api. - continue - relative_public_module = public_module[len('tensorflow.'):] - full_name = (relative_public_module + '.' + sym.name - if relative_public_module else sym.name) - op_to_endpoint_name[obj].append(full_name) - - # Generate Python ApiDef overrides. - for op, endpoint_names in op_to_endpoint_name.items(): - graph_op_name = snake_to_camel_graph_op_names[op.__name__] - api_def = self._CreatePythonApiDef( - name_to_base_api_def[graph_op_name], endpoint_names) - - if api_def: - file_path = _GetApiDefFilePath(graph_op_name) - api_defs = new_api_defs_map[file_path] - api_defs.op.extend([api_def]) - - self._AddHiddenOpOverrides(name_to_base_api_def, new_api_defs_map) - - old_api_defs_map = _GetGoldenApiDefs() - for file_path, new_api_defs in new_api_defs_map.items(): - # Get new ApiDef string. - new_api_defs_str = str(new_api_defs) - - # Get current ApiDef for the given file. - old_api_defs_str = ( - old_api_defs_map[file_path] if file_path in old_api_defs_map else '') - - if old_api_defs_str == new_api_defs_str: - continue - - if FLAGS.update_goldens: - logging.info('Updating %s...' % file_path) - file_io.write_string_to_file(file_path, new_api_defs_str) - else: - self.assertMultiLineEqual( - old_api_defs_str, new_api_defs_str, - 'To update golden API files, run api_compatibility_test locally ' - 'with --update_goldens=True flag.') - - for file_path in set(old_api_defs_map) - set(new_api_defs_map): - if FLAGS.update_goldens: - logging.info('Deleting %s...' % file_path) - file_io.delete_file(file_path) - else: - self.fail( - '%s file is no longer needed and should be removed.' - 'To update golden API files, run api_compatibility_test locally ' - 'with --update_goldens=True flag.' % file_path) - - if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 6f4f0f98597ba4a87cfbf5406e33545a53e978b0..fd5d0058441c366be76defdde5dac73405837262 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -184,7 +184,8 @@ do_pylint() { # W0312 mixed-indentation # C0330 bad-continuation # C0301 line-too-long - grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301)' ${OUTPUT_FILE} > ${ERRORS_FILE} + # C0326 bad-whitespace + grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326)' ${OUTPUT_FILE} > ${ERRORS_FILE} N_ERRORS=0 while read -r LINE; do @@ -320,7 +321,7 @@ do_external_licenses_check(){ EXTRA_LICENSES_FILE="$(mktemp)_extra_licenses.log" echo "Getting external dependencies for ${BUILD_TARGET}" - bazel query "attr('licenses', 'notice', deps(${BUILD_TARGET}))" --no_implicit_deps --no_host_deps --keep_going \ + bazel query "attr('licenses', 'notice', deps(${BUILD_TARGET}))" --keep_going \ | grep -E -v "^//tensorflow" \ | sed -e 's|:.*||' \ | sort \ @@ -329,7 +330,7 @@ do_external_licenses_check(){ echo echo "Getting list of external licenses mentioned in ${LICENSES_TARGET}." - bazel query "deps(${LICENSES_TARGET})" --no_implicit_deps --no_host_deps --keep_going \ + bazel query "deps(${LICENSES_TARGET})" --keep_going \ | grep -E -v "^//tensorflow" \ | sed -e 's|:.*||' \ | sort \ @@ -343,6 +344,18 @@ do_external_licenses_check(){ EXTERNAL_LICENSES_CHECK_END_TIME=$(date +'%s') + # Blacklist + echo ${MISSING_LICENSES_FILE} + grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt + mv temp.txt ${MISSING_LICENSES_FILE} + + # Whitelist + echo ${EXTRA_LICENSE_FILE} + grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt + mv temp.txt ${EXTRA_LICENSES_FILE} + + + echo echo "do_external_licenses_check took $((EXTERNAL_LICENSES_CHECK_END_TIME - EXTERNAL_LICENSES_CHECK_START_TIME)) s" echo @@ -516,9 +529,14 @@ do_check_futures_test() { python check_futures_test.py } +do_check_file_name_test() { + cd "$ROOT_DIR/tensorflow/tools/test" + python file_name_test.py +} + # Supply all sanity step commands and descriptions -SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity") -SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency") +SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity" "do_check_file_name_test") +SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency" "Check file names for cases") INCREMENTAL_FLAG="" DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp" diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 71744c04f2f432bc76eadfac406233ad8241a52a..d406b83a6246d18c335fb52cea1256d7809fa61a 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -43,8 +43,8 @@ pip2 install --upgrade werkzeug==0.11.10 pip3 install --upgrade werkzeug==0.11.10 # Install bleach. html5lib will be picked up as a dependency. -pip2 install --upgrade bleach==1.5.0 -pip3 install --upgrade bleach==1.5.0 +pip2 install --upgrade bleach==2.0.0 +pip3 install --upgrade bleach==2.0.0 # Install markdown. pip2 install --upgrade markdown==2.6.8 diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat index 957729bb37db3ae49800c277f4090a52117c699d..c1bc71850754c5b4b42a6eb50be465ba8f98c218 100644 --- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat @@ -36,7 +36,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" :: Run cmake to create Visual Studio Project files. -%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% +%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX :: Run msbuild in the resulting VS project files to build a pip package. %MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat index 5a362de3992156fea8a5fc6ab4c70ba67ab47f89..b87e4a9bec41264827d415a11dfa6f23aeda725d 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat @@ -37,7 +37,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" :: Run cmake to create Visual Studio Project files. -%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% +%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX :: Run msbuild in the resulting VS project files to build a pip package. %MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh index fa28e3d79ca4ee5f429a41dd3e871248d5c047ca..583d1d5f09527861015458c636af2259b34d45f8 100755 --- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -41,7 +41,7 @@ run_configure_for_cpu_build # build_libtensorflow_tarball in ../builds/libtensorflow.sh # cannot be used on Windows since it relies on pkg_tar rules. # So we do something special here -bazel build -c opt \ +bazel build -c opt --copt=/arch:AVX \ tensorflow:libtensorflow.so \ tensorflow/tools/lib_package:clicenses_generate \ tensorflow/java:libtensorflow_jni.so \ diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh index 573c926203fc76b787ba08b10bd71c8effda29b6..94276c6c5c9ce897ca24f03efe3d93e1ea1e00c9 100644 --- a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh @@ -41,7 +41,7 @@ run_configure_for_gpu_build # build_libtensorflow_tarball in ../builds/libtensorflow.sh # cannot be used on Windows since it relies on pkg_tar rules. # So we do something special here -bazel build -c opt \ +bazel build -c opt --copt=/arch:AVX \ tensorflow:libtensorflow.so \ tensorflow/tools/lib_package:clicenses_generate \ tensorflow/java:libtensorflow_jni.so \ diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index f678681dac27805d6748b426698b4fe2a7c08067..6e90b286c99f894ddd25268afc69043759571c36 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -46,8 +46,9 @@ class APIChangeSpec(object): """ -class _FileEditTuple(collections.namedtuple( - "_FileEditTuple", ["comment", "line", "start", "old", "new"])): +class _FileEditTuple( + collections.namedtuple("_FileEditTuple", + ["comment", "line", "start", "old", "new"])): """Each edit that is recorded by a _FileEditRecorder. Fields: @@ -179,8 +180,7 @@ class _ASTCallVisitor(ast.NodeVisitor): function_renames = self._api_change_spec.function_renames try: new_name = function_renames[full_name] - self._file_edit.add("Renamed function %r to %r" % (full_name, - new_name), + self._file_edit.add("Renamed function %r to %r" % (full_name, new_name), node.lineno, node.col_offset, full_name, new_name) except KeyError: pass @@ -227,7 +227,7 @@ class _ASTCallVisitor(ast.NodeVisitor): # loop over lines while 1: # Reverse the text to and regular expression search for whitespace - text = self._lines[line-1] + text = self._lines[line - 1] reversed_preceding_text = text[:col][::-1] # First find if a [ can be found with only whitespace between it and # col. @@ -248,8 +248,8 @@ class _ASTCallVisitor(ast.NodeVisitor): # node ranges to filter out spurious #'s that appear in string # literals. comment_start = prev_line.find("#") - if comment_start == -1: - col = len(prev_line) -1 + if comment_start == -1: + col = len(prev_line) - 1 elif find_string_chars.search(prev_line[comment_start:]) is None: col = comment_start else: @@ -260,7 +260,6 @@ class _ASTCallVisitor(ast.NodeVisitor): # it is not possible to use that in an argument. return node.lineno, node.col_offset - def visit_Call(self, node): # pylint: disable=invalid-name """Handle visiting a call node in the AST. @@ -268,7 +267,6 @@ class _ASTCallVisitor(ast.NodeVisitor): node: Current Node """ - # Find a simple attribute name path e.g. "tf.foo.bar" full_name = self._get_attribute_full_path(node.func) @@ -293,18 +291,21 @@ class _ASTCallVisitor(ast.NodeVisitor): lineno, col_offset = self._find_true_position(arg) if lineno is None or col_offset is None: self._file_edit.add( - "Failed to add keyword %r to reordered function %r" - % (reordered[idx], full_name), arg.lineno, arg.col_offset, - "", "", + "Failed to add keyword %r to reordered function %r" % + (reordered[idx], full_name), + arg.lineno, + arg.col_offset, + "", + "", error="A necessary keyword argument failed to be inserted.") else: keyword_arg = reordered[idx] if (full_name in function_keyword_renames and keyword_arg in function_keyword_renames[full_name]): keyword_arg = function_keyword_renames[full_name][keyword_arg] - self._file_edit.add("Added keyword %r to reordered function %r" - % (reordered[idx], full_name), lineno, - col_offset, "", keyword_arg + "=") + self._file_edit.add("Added keyword %r to reordered function %r" % + (reordered[idx], full_name), lineno, col_offset, + "", keyword_arg + "=") # Examine each keyword argument and convert it to the final renamed form renamed_keywords = ({} if full_name not in function_keyword_renames else @@ -322,11 +323,11 @@ class _ASTCallVisitor(ast.NodeVisitor): # value. key_start = argval_col_offset - len(argkey) - 1 key_end = key_start + len(argkey) + 1 - if (self._lines[argval_lineno - 1][key_start:key_end] == - argkey + "="): + if (self._lines[argval_lineno - 1][key_start:key_end] == argkey + + "="): self._file_edit.add("Renamed keyword argument from %r to %r" % - (argkey, renamed_keywords[argkey]), - argval_lineno, + (argkey, + renamed_keywords[argkey]), argval_lineno, argval_col_offset - len(argkey) - 1, argkey + "=", renamed_keywords[argkey] + "=") continue @@ -335,7 +336,8 @@ class _ASTCallVisitor(ast.NodeVisitor): (argkey, renamed_keywords[argkey]), argval.lineno, argval.col_offset - len(argkey) - 1, - "", "", + "", + "", error="Failed to find keyword lexographically. Fix manually.") ast.NodeVisitor.generic_visit(self, node) @@ -352,7 +354,7 @@ class _ASTCallVisitor(ast.NodeVisitor): if full_name in self._api_change_spec.change_to_function: if not hasattr(node, "is_function_for_call"): new_text = full_name + "()" - self._file_edit.add("Changed %r to %r"%(full_name, new_text), + self._file_edit.add("Changed %r to %r" % (full_name, new_text), node.lineno, node.col_offset, full_name, new_text) ast.NodeVisitor.generic_visit(self, node) @@ -380,8 +382,8 @@ class ASTCodeUpgrader(object): # Write to a temporary file, just in case we are doing an implace modify. with open(in_filename, "r") as in_file, \ tempfile.NamedTemporaryFile("w", delete=False) as temp_file: - ret = self.process_opened_file( - in_filename, in_file, out_filename, temp_file) + ret = self.process_opened_file(in_filename, in_file, out_filename, + temp_file) shutil.move(temp_file.name, out_filename) return ret @@ -424,6 +426,7 @@ class ASTCodeUpgrader(object): out_file.write(out_text) text += "\n" return 1, text, process_errors + # pylint: enable=broad-except def process_tree(self, root_directory, output_root_directory, @@ -444,16 +447,16 @@ class ASTCodeUpgrader(object): # make sure output directory doesn't exist if output_root_directory and os.path.exists(output_root_directory): - print("Output directory %r must not already exist." % ( - output_root_directory)) + print("Output directory %r must not already exist." % + (output_root_directory)) sys.exit(1) # make sure output directory does not overlap with root_directory norm_root = os.path.split(os.path.normpath(root_directory)) norm_output = os.path.split(os.path.normpath(output_root_directory)) if norm_root == norm_output: - print("Output directory %r same as input directory %r" % ( - root_directory, output_root_directory)) + print("Output directory %r same as input directory %r" % + (root_directory, output_root_directory)) sys.exit(1) # Collect list of files to process (we do this to correctly handle if the @@ -465,14 +468,16 @@ class ASTCodeUpgrader(object): copy_files = [f for f in file_list if not f.endswith(".py")] for filename in py_files: fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join( - output_root_directory, os.path.relpath(fullpath, root_directory)) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath(fullpath, + root_directory)) files_to_process.append((fullpath, fullpath_output)) if copy_other_files: for filename in copy_files: fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join( - output_root_directory, os.path.relpath(fullpath, root_directory)) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath( + fullpath, root_directory)) files_to_copy.append((fullpath, fullpath_output)) file_count = 0 @@ -641,18 +646,17 @@ class TFAPIChangeSpec(APIChangeSpec): "tf.concat": ["concat_dim", "values", "name"], "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], "tf.nn.softmax_cross_entropy_with_logits": [ - "logits", "labels", "dim", "name"], + "logits", "labels", "dim", "name" + ], "tf.nn.sparse_softmax_cross_entropy_with_logits": [ - "logits", "labels", "name"], - "tf.nn.sigmoid_cross_entropy_with_logits": [ - "logits", "labels", "name"], + "logits", "labels", "name" + ], + "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"], "tf.op_scope": ["values", "name", "default_name"], } # Specially handled functions. - self.function_handle = { - "tf.reverse": self._reverse_handler - } + self.function_handle = {"tf.reverse": self._reverse_handler} @staticmethod def _reverse_handler(file_edit_recorder, node): @@ -661,12 +665,13 @@ class TFAPIChangeSpec(APIChangeSpec): comment = ("ERROR: tf.reverse has had its argument semantics changed\n" "significantly the converter cannot detect this reliably, so you" "need to inspect this usage manually.\n") - file_edit_recorder.add(comment, - node.lineno, - node.col_offset, - "tf.reverse", - "tf.reverse", - error="tf.reverse requires manual check.") + file_edit_recorder.add( + comment, + node.lineno, + node.col_offset, + "tf.reverse", + "tf.reverse", + error="tf.reverse requires manual check.") if __name__ == "__main__": diff --git a/tensorflow/tools/dist_test/build_server.sh b/tensorflow/tools/dist_test/build_server.sh index 878fabd248f3c1dd5cb79983df5220ebf5893026..225c0347416ec8c8fef855946d18e838bd767690 100755 --- a/tensorflow/tools/dist_test/build_server.sh +++ b/tensorflow/tools/dist_test/build_server.sh @@ -16,14 +16,15 @@ # # Builds the test server for distributed (GRPC) TensorFlow # -# Usage: build_server.sh [--test] +# Usage: build_server.sh [--test] # # Arguments: # docker_image_name: Name of the docker image to build. # E.g.: tensorflow/tf_grpc_test_server:0.11.0rc1 # -# whl_url: URL from which the TensorFlow whl file will be downloaded. +# whl_file_location: URL from which the TensorFlow whl file will be downloaded. # E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl +# E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl # # The optional flag --test lets the script to use the Dockerfile for the # testing GRPC server. Without the flag, the script will build the non-test @@ -41,11 +42,11 @@ die() { # Check arguments if [[ $# -lt 2 ]]; then - die "Usage: $0 [--test]" + die "Usage: $0 [--test]" fi DOCKER_IMG_NAME=$1 -WHL_URL=$2 +WHL_FILE_LOCATION=$2 shift 2 # Current script directory @@ -53,7 +54,7 @@ DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BUILD_DIR=$(mktemp -d) echo "" -echo "Using whl file URL: ${WHL_URL}" +echo "Using whl file URL: ${WHL_FILE_LOCATION}" echo "Building in temporary directory: ${BUILD_DIR}" cp -r ${DIR}/* "${BUILD_DIR}"/ || \ @@ -65,9 +66,15 @@ if [[ $1 == "--test" ]]; then fi echo "Using Docker file: ${DOCKER_FILE}" +if [[ $WHL_FILE_LOCATION =~ 'http://' || $WHL_FILE_LOCATION =~ 'https://' ]]; then + # Download whl file into the build context directory. + wget -P "${BUILD_DIR}" "${WHL_FILE_LOCATION}" || \ + die "Failed to download tensorflow whl file from URL: ${WHL_FILE_LOCATION}" +else + cp "${WHL_FILE_LOCATION}" "${BUILD_DIR}" +fi + # Download whl file into the build context directory. -wget -P "${BUILD_DIR}" ${WHL_URL} || \ - die "Failed to download tensorflow whl file from URL: ${WHL_URL}" if [[ ! -f "${DOCKER_FILE}" ]]; then die "ERROR: Unable to find dockerfile: ${DOCKER_FILE}" diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh index 7d7f92d246e1ca0b519ac3bf30fde673621ff755..435f9d0dc9c55a3dcfc45e7e46f279b4679a9086 100755 --- a/tensorflow/tools/dist_test/local_test.sh +++ b/tensorflow/tools/dist_test/local_test.sh @@ -24,19 +24,20 @@ # 3) Call a script to launch a k8s TensorFlow GRPC cluster inside the container # and run the distributed test suite. # -# Usage: local_test.sh +# Usage: local_test.sh # [--leave_container_running] # [--model_name ] # [--num_workers ] # [--num_parameter_servers ] # [--sync_replicas] # -# E.g., local_test.sh --model_name CENSUS_WIDENDEEP -# local_test.sh --num_workers 3 --num_parameter_servers 3 +# E.g., local_test.sh --model_name CENSUS_WIDENDEEP +# local_test.sh --num_workers 3 --num_parameter_servers 3 # # Arguments: -# -# Specify custom TensorFlow whl file URL to install in the test Docker image. +# whl_file_location: URL from which the TensorFlow whl file will be acquired. +# E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl +# E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl # # --leave_container_running: Do not stop the docker-in-docker container after # the termination of the tests, e.g., for debugging @@ -81,9 +82,9 @@ NUM_WORKERS=2 NUM_PARAMETER_SERVERS=2 SYNC_REPLICAS_FLAG="" -WHL_URL=${1} -if [[ -z "${WHL_URL}" ]]; then - die "whl file URL is not specified" +WHL_FILE_LOCATION=${1} +if [[ -z "${WHL_FILE_LOCATION}" ]]; then + die "whl file location is not specified" fi while true; do @@ -98,8 +99,8 @@ while true; do NUM_PARAMETER_SERVERS=$2 elif [[ $1 == "--sync_replicas" ]]; then SYNC_REPLICAS_FLAG="--sync_replicas" - elif [[ $1 == "--whl_url" ]]; then - WHL_URL=$2 + elif [[ $1 == "--WHL_FILE_LOCATION" ]]; then + WHL_FILE_LOCATION=$2 fi shift @@ -130,15 +131,19 @@ fi # Create docker build context directory. BUILD_DIR=$(mktemp -d) echo "" -echo "Using whl file URL: ${WHL_URL}" +echo "Using whl file location: ${WHL_FILE_LOCATION}" echo "Building in temporary directory: ${BUILD_DIR}" cp -r ${DIR}/* "${BUILD_DIR}"/ || \ die "Failed to copy files to ${BUILD_DIR}" -# Download whl file into the build context directory. -wget -P "${BUILD_DIR}" ${WHL_URL} || \ - die "Failed to download tensorflow whl file from URL: ${WHL_URL}" +if [[ $WHL_FILE_LOCATION =~ 'http://' || $WHL_FILE_LOCATION =~ 'https://' ]]; then + # Download whl file into the build context directory. + wget -P "${BUILD_DIR}" "${WHL_FILE_LOCATION}" || \ + die "Failed to download tensorflow whl file from URL: ${WHL_FILE_LOCATION}" +else + cp "${WHL_FILE_LOCATION}" "${BUILD_DIR}" +fi # Build docker image for test. docker build ${NO_CACHE_FLAG} -t ${DOCKER_IMG_NAME} \ diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 5dc4a053fd2cae7d83739507fea31e7afc92d77c..d16761c3675942838fd2be0ea6e0b7463a3bf249 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -70,7 +70,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.5 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.6 --depth=1 https://github.com/tensorflow/tensorflow.git . # TODO(craigcitro): Don't install the pip package, since it makes it # more difficult to experiment with local changes. Instead, just add diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl index 96b260ad3aeb78622dd1ad276f7d524dd598e3bf..3690e7dfe57a4682276a90b10cb84c9a329b3f5e 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl +++ b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl @@ -3,7 +3,7 @@ FROM tensorflow/tensorflow:latest-devel LABEL maintainer="Clayne Robison" # These arguments are parameterized. Use --build-args to override. -ARG TF_BRANCH=r1.5 +ARG TF_BRANCH=r1.6 ARG WHL_DIR=/whl RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 07ffd3839a32ef194100322e54b9133412e4b664..4ef37881bc91aaa58bab031c69b4a96c2a9d8ec1 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -79,7 +79,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.5 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.6 --depth=1 https://github.com/tensorflow/tensorflow.git . # Configure the build for our CUDA configuration. ENV CI_BUILD_PYTHON python diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py index b5df633800ae5a3ce67cf03910d472b9908d6249..543b5fa6fefcd8e8dca99ad7eac7cca76781ccd3 100644 --- a/tensorflow/tools/docs/pretty_docs.py +++ b/tensorflow/tools/docs/pretty_docs.py @@ -162,7 +162,7 @@ def _build_class_page(page_info): parts.append(h3.format(**method_info.__dict__)) if method_info.signature is not None: - parts.append(_build_signature(method_info)) + parts.append(_build_signature(method_info, use_full_name=False)) parts.append(method_info.doc.docstring) parts.append(_build_function_details(method_info.doc.function_details)) @@ -259,14 +259,14 @@ def _build_module_page(page_info): return ''.join(parts) -def _build_signature(obj_info): +def _build_signature(obj_info, use_full_name=True): """Returns a md code block showing the function signature.""" # Special case tf.range, since it has an optional first argument if obj_info.full_name == 'tf.range': return ( '``` python\n' - "range(limit, delta=1, dtype=None, name='range')\n" - "range(start, limit, delta=1, dtype=None, name='range')\n" + "tf.range(limit, delta=1, dtype=None, name='range')\n" + "tf.range(start, limit, delta=1, dtype=None, name='range')\n" '```\n\n') parts = ['``` python'] @@ -281,7 +281,11 @@ def _build_signature(obj_info): sig = ',\n'.join(' %s' % sig_item for sig_item in obj_info.signature) sig = '\n'+sig+'\n' - parts.append(signature_template.format(name=obj_info.short_name, sig=sig)) + if use_full_name: + obj_name = obj_info.full_name + else: + obj_name = obj_info.short_name + parts.append(signature_template.format(name=obj_name, sig=sig)) parts.append('```\n\n') return '\n'.join(parts) diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index 593c654f9fbbabe7e89c1ff2a43e56d30e8919d6..214ec721e2c9b8bdd761a1cb7a92a74f4a2a42a0 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -86,8 +86,17 @@ void CreateConstNode(const Tensor& tensor, const string& name, SetNodeTensorAttr("value", tensor, node_def); } +string GetMonolithicTensorKey(const string& tensor_slice_name) { + std::vector names = Split(tensor_slice_name, "/"); + if (StringPiece(names[names.size() - 1]).starts_with("part_")) { + CHECK_GE(names.size(), 2); + names.pop_back(); + } + return Join(names, "/"); +} + Status ObtainTensorSlice(const GraphDef& input_graph_def, - const string& tensor_name, + const string& target_name, string* shape_slice_string) { string restore_node_name; for (const auto& node : input_graph_def.node()) { @@ -95,39 +104,53 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def, if (node_name_parts.size() == 2 && StringPiece(node_name_parts[0]).starts_with("save") && StringPiece(node_name_parts[1]).starts_with("Assign") && - node.input(0) == tensor_name) { + node.input(0) == target_name) { restore_node_name = node.input(1); break; } } + + std::vector restore_node_parts = Split(restore_node_name, ":"); + CHECK_LE(restore_node_parts.size(), 2); + string tensor_names_node; string shape_and_slices_node; for (const auto& node : input_graph_def.node()) { - if ((node.name() == restore_node_name) && (node.op() == "RestoreV2")) { + if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) { + tensor_names_node = node.input(1); shape_and_slices_node = node.input(2); break; } } + + int offset = -1; + for (const auto& node : input_graph_def.node()) { + if (node.name() == tensor_names_node) { + Tensor tensor_names_tensor; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor)); + const auto& tensor_names_value = tensor_names_tensor.flat(); + for (int i = 0; i < tensor_names_value.size(); i++) { + if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) { + offset = i; + break; + } + } + } + } + if (offset == -1) { + return errors::Internal("Unable to find RestoreV2 entry for variable: ", + target_name); + } for (const auto& node : input_graph_def.node()) { if (node.name() == shape_and_slices_node) { Tensor shape_and_slices_tensor; TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor)); const auto& shape_and_slices_value = shape_and_slices_tensor.flat(); - *shape_slice_string = shape_and_slices_value(0); + *shape_slice_string = shape_and_slices_value(offset); return Status::OK(); } } - return errors::Internal("Unable to find slice for variable: ", tensor_name); -} - -string GetMonolithicTensorKey(const string& tensor_slice_name) { - std::vector names = Split(tensor_slice_name, "/"); - CHECK_GE(names.size(), 2); - CHECK(StringPiece(names[names.size() - 1]).starts_with("part_")); - - // Remove the "part_x" suffix - names.pop_back(); - return Join(names, "/"); + return errors::Internal("Unable to find slice for variable: ", target_name); } Status ReadTensorFromCheckpoint( @@ -181,6 +204,14 @@ Status ObtainVariableInfo( return Status::OK(); } +Status RemoveInputAtIndex(NodeDef* n, int index) { + for (int i = index; i < n->input_size() - 1; i++) { + n->mutable_input()->SwapElements(i, i + 1); + } + n->mutable_input()->RemoveLast(); + return Status::OK(); +} + Status SparsifyGatherInternal( const GraphDef& input_graph_def, const std::unique_ptr >& @@ -301,13 +332,13 @@ Status SparsifyGatherInternal( TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint( weights_node.name(), ckpt_reader, (*shapes_and_slices)[weights_node.name()], &weight)); - // Add both both weight and identity node names. - removed_node_names.push_back(weights_node.name()); - removed_node_names.push_back(match.inputs[0].node.name()); - for (auto input_node : match.inputs[0].node.input()) { - auto parsed_input = StringReplace(input_node, "^", "", true); - refs[parsed_input]--; - } + } + // Add both both weight and identity node names. + removed_node_names.push_back(weights_node.name()); + removed_node_names.push_back(match.inputs[0].node.name()); + for (auto input_node : match.inputs[0].node.input()) { + auto parsed_input = StringReplace(input_node, "^", "", true); + refs[parsed_input]--; } Tensor indices_tensor; Tensor values_tensor; @@ -468,26 +499,49 @@ Status SparsifyGatherInternal( continue; } int j = 0; + bool deleted_inputs = false; while (j < replaced_graph_def.node(i).input_size()) { if (replaced_graph_def.node(i).input(j) == name || replaced_graph_def.node(i).input(j) == ("^" + name)) { - replaced_graph_def.mutable_node(i)->mutable_input()->SwapElements( - j, replaced_graph_def.node(i).input_size() - 1); - replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast(); + TF_RETURN_IF_ERROR( + RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j)); + deleted_inputs = true; continue; } j++; } - if (!replaced_graph_def.node(i).input_size()) { - if ((refs.find(replaced_graph_def.node(i).name()) != refs.end()) && - (refs[replaced_graph_def.node(i).name()] == 0)) { + if (deleted_inputs) { + if (replaced_graph_def.node(i).op() == "ConcatV2") { + if (replaced_graph_def.node(i).input_size() > 2) { + SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1, + replaced_graph_def.mutable_node(i)); + } else if (replaced_graph_def.node(i).input_size() == 2) { + if (refs[replaced_graph_def.node(i).input(1)] != 1) { + return errors::Internal( + "Expect axis tensor of ConcatV2 node to only be referenced " + "once."); + } + refs[replaced_graph_def.node(i).input(1)] -= 1; + removed_node_names.push_back(replaced_graph_def.node(i).input(1)); + replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast(); + replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N"); + replaced_graph_def.mutable_node(i)->set_op("Identity"); + } else { + return errors::Internal( + "ConcatV2 should have at least two elements"); + } + } + if ((replaced_graph_def.node(i).op() == "Assign" || + replaced_graph_def.node(i).op() == "Reshape" || + replaced_graph_def.node(i).op() == "Equal" || + replaced_graph_def.node(i).op() == "Mean" || + replaced_graph_def.node(i).op() == "ScalarSummary") && + replaced_graph_def.node(i).input_size() == 1) { + removed_node_names.push_back(replaced_graph_def.node(i).name()); + } + if (!replaced_graph_def.node(i).input_size()) { removed_node_names.push_back(replaced_graph_def.node(i).name()); } - } - - if (replaced_graph_def.node(i).op() == "Assign" && - replaced_graph_def.node(i).input_size() == 1) { - removed_node_names.push_back(replaced_graph_def.node(i).name()); } i++; } @@ -528,17 +582,22 @@ Status SparsifyGather(const GraphDef& input_graph_def, }; // clang-format on + GraphDef cleaned_input_graph_def; + RemoveAttributes(input_graph_def, {"_output_shapes"}, + &cleaned_input_graph_def); + GraphDef temp_output; std::unique_ptr ckpt_reader; TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader)); std::unique_ptr > shapes_and_slices; - TF_RETURN_IF_ERROR(ObtainVariableInfo(input_graph_def, &shapes_and_slices)); + TF_RETURN_IF_ERROR( + ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices)); - TF_RETURN_IF_ERROR(SparsifyGatherInternal(input_graph_def, shapes_and_slices, - context, gather_pattern, - ckpt_reader, &temp_output)); + TF_RETURN_IF_ERROR(SparsifyGatherInternal( + cleaned_input_graph_def, shapes_and_slices, context, gather_pattern, + ckpt_reader, &temp_output)); TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices, context, gather_v2_pattern, diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index 6627df1331a6eaf49857c3ecba4a4d55859cad7c..d41321c9a6df755eed099ec453f162e2132cfb57 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -71,7 +71,7 @@ class SparsifyGatherTest : public ::testing::Test { } void TestSinglePartition(bool gather_v2, bool include_shared_init, - bool test_variable, + bool test_variable, bool test_kept_concat, const string& shared_init_name = "group_deps") { GraphDef graph_def; @@ -106,11 +106,15 @@ class SparsifyGatherTest : public ::testing::Test { NodeDef* save_const_node = CreateNode("save/Const", "Const", {}, &graph_def); + Tensor tensor_names_values(DT_STRING, TensorShape({1})); + test::FillValues(&tensor_names_values, {"w"}); NodeDef* tensor_names_node = CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); + SetNodeTensorAttr("value", tensor_names_values, + tensor_names_node); + NodeDef* tensor_shapes_slices_node = CreateNode( "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); - Tensor shapes_slices_val(DT_STRING, TensorShape({1})); shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; SetNodeTensorAttr("value", shapes_slices_val, @@ -139,6 +143,26 @@ class SparsifyGatherTest : public ::testing::Test { } } + NodeDef* concat_axis_node = + CreateNode("linear/concat/axis", "Const", {}, &graph_def); + NodeDef* concat_input_node = + CreateNode("concat/input/node", "Const", {}, &graph_def); + NodeDef* concat_node = nullptr; + if (!test_kept_concat) { + concat_node = CreateNode( + "concat/node", "ConcatV2", + {identity_node, concat_input_node, concat_axis_node}, &graph_def); + SetNodeAttr("N", 2, concat_node); + } else { + NodeDef* concat_input_node_2 = + CreateNode("concat/input/node_2", "Const", {}, &graph_def); + concat_node = CreateNode("concat/node", "ConcatV2", + {identity_node, concat_input_node, + concat_input_node_2, concat_axis_node}, + &graph_def); + SetNodeAttr("N", 3, concat_node); + } + // Run the op. GraphDef result; TransformFuncContext context; @@ -166,6 +190,23 @@ class SparsifyGatherTest : public ::testing::Test { EXPECT_EQ(1, node_lookup.count("ids")); EXPECT_EQ("Const", node_lookup.at("ids")->op()); + EXPECT_EQ(1, node_lookup.count("concat/node")); + + if (!test_kept_concat) { + EXPECT_EQ(0, node_lookup.count("linear/concat/axis")); + EXPECT_EQ("Identity", node_lookup.at("concat/node")->op()); + EXPECT_EQ(1, node_lookup.at("concat/node")->input_size()); + EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0)); + } else { + EXPECT_EQ(1, node_lookup.count("linear/concat/axis")); + EXPECT_EQ("ConcatV2", node_lookup.at("concat/node")->op()); + EXPECT_EQ(3, node_lookup.at("concat/node")->input_size()); + EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0)); + EXPECT_EQ("concat/input/node_2", node_lookup.at("concat/node")->input(1)); + EXPECT_EQ("linear/concat/axis", node_lookup.at("concat/node")->input(2)); + EXPECT_EQ(2, node_lookup.at("concat/node")->attr().at("N").i()); + } + EXPECT_EQ(1, node_lookup.count("w/part_1/indices")); EXPECT_EQ("Const", node_lookup.at("w/part_1/indices")->op()); Tensor expected_indices_tensor(DT_INT64, TensorShape({3})); @@ -273,6 +314,29 @@ class SparsifyGatherTest : public ::testing::Test { SetNodeTensorAttr("value", weights, w_node1); SetNodeTensorAttr("value", weights, w_node2); } else { + NodeDef* save_const_node = + CreateNode("save/Const", "Const", {}, &graph_def); + + NodeDef* tensor_names_node = + CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); + Tensor tensor_names_values(DT_STRING, TensorShape({2})); + test::FillValues(&tensor_names_values, {"w1", "w2"}); + SetNodeTensorAttr("value", tensor_names_values, + tensor_names_node); + + NodeDef* tensor_shapes_slices_node = CreateNode( + "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); + Tensor shapes_slices_val(DT_STRING, TensorShape({2})); + shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; + shapes_slices_val.flat()(1) = "4 1 0,4:0,1"; + SetNodeTensorAttr("value", shapes_slices_val, + tensor_shapes_slices_node); + + NodeDef* restore_node = CreateNode( + "save/RestoreV2", "RestoreV2", + {save_const_node, tensor_names_node, tensor_shapes_slices_node}, + &graph_def); + w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def); zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor", @@ -284,23 +348,7 @@ class SparsifyGatherTest : public ::testing::Test { assign_node1 = CreateNode("w1/part_1/Assign", "Assign", {w_node1, zeros_node1}, &graph_def); - NodeDef* save_const_node = - CreateNode("save/Const", "Const", {}, &graph_def); - NodeDef* tensor_names_node1 = - CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); - NodeDef* tensor_shapes_slices_node1 = CreateNode( - "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); - - Tensor shapes_slices_val1(DT_STRING, TensorShape({1})); - shapes_slices_val1.flat()(0) = "4 1 0,4:0,1"; - SetNodeTensorAttr("value", shapes_slices_val1, - tensor_shapes_slices_node1); - - NodeDef* restore_node1 = CreateNode( - "save/RestoreV2", "RestoreV2", - {save_const_node, tensor_names_node1, tensor_shapes_slices_node1}, - &graph_def); - CreateNode("save/Assign", "Assign", {w_node1, restore_node1}, &graph_def); + CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def); w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def); zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor", @@ -312,21 +360,7 @@ class SparsifyGatherTest : public ::testing::Test { assign_node2 = CreateNode("w2/part_1/Assign", "Assign", {w_node2, zeros_node2}, &graph_def); - NodeDef* tensor_names_node2 = - CreateNode("save/RestoreV2_1/tensor_names", "Const", {}, &graph_def); - NodeDef* tensor_shapes_slices_node2 = CreateNode( - "save/RestoreV2_1/shape_and_slices", "Const", {}, &graph_def); - - Tensor shapes_slices_val2(DT_STRING, TensorShape({1})); - shapes_slices_val2.flat()(0) = "4 1 0,4:0,1"; - SetNodeTensorAttr("value", shapes_slices_val2, - tensor_shapes_slices_node2); - - NodeDef* restore_node2 = CreateNode( - "save/RestoreV2_1", "RestoreV2", - {save_const_node, tensor_names_node2, tensor_shapes_slices_node2}, - &graph_def); - CreateNode("save/Assign_1", "Assign", {w_node2, restore_node2}, + CreateNode("save/Assign_1", "Assign", {w_node2, restore_node}, &graph_def); BundleWriter writer(Env::Default(), checkpoint_path); @@ -344,6 +378,13 @@ class SparsifyGatherTest : public ::testing::Test { MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def); MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def); + NodeDef* concat_axis_node = + CreateNode("linear/concat/axis", "Const", {}, &graph_def); + NodeDef* concat_node = CreateNode( + "concat/node", "ConcatV2", + {identity_node1, identity_node2, concat_axis_node}, &graph_def); + SetNodeAttr("N", 2, concat_node); + // Shared init node if (include_shared_init) { if (!test_variable) { @@ -515,6 +556,9 @@ class SparsifyGatherTest : public ::testing::Test { node_lookup.at("gather2/LookupTableFind")->input(2)); EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0)); + EXPECT_EQ(0, node_lookup.count("linear/concat/axis")); + EXPECT_EQ(0, node_lookup.count("concat/node")); + // Check control deps. EXPECT_EQ(2, node_lookup.at(shared_init_name)->input_size()); EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(), @@ -550,18 +594,31 @@ class SparsifyGatherTest : public ::testing::Test { }; TEST_F(SparsifyGatherTest, TestSinglePartition) { - TestSinglePartition(false, false, false); - TestSinglePartition(false, true, false); - TestSinglePartition(true, false, false); - TestSinglePartition(true, true, false); - TestSinglePartition(false, false, true); - TestSinglePartition(false, true, true); - TestSinglePartition(true, false, true); - TestSinglePartition(true, true, true); - TestSinglePartition(false, true, false, "shared_inits"); - TestSinglePartition(true, true, false, "shared_inits"); - TestSinglePartition(false, true, true, "shared_inits"); - TestSinglePartition(true, true, true, "shared_inits"); + TestSinglePartition(false, false, false, false); + TestSinglePartition(false, true, false, false); + TestSinglePartition(true, false, false, false); + TestSinglePartition(true, true, false, false); + TestSinglePartition(false, false, true, false); + TestSinglePartition(false, true, true, false); + TestSinglePartition(true, false, true, false); + TestSinglePartition(true, true, true, false); + TestSinglePartition(false, true, false, false, "shared_inits"); + TestSinglePartition(true, true, false, false, "shared_inits"); + TestSinglePartition(false, true, true, false, "shared_inits"); + TestSinglePartition(true, true, true, false, "shared_inits"); + + TestSinglePartition(false, false, false, true); + TestSinglePartition(false, true, false, true); + TestSinglePartition(true, false, false, true); + TestSinglePartition(true, true, false, true); + TestSinglePartition(false, false, true, true); + TestSinglePartition(false, true, true, true); + TestSinglePartition(true, false, true, true); + TestSinglePartition(true, true, true, true); + TestSinglePartition(false, true, false, true, "shared_inits"); + TestSinglePartition(true, true, false, true, "shared_inits"); + TestSinglePartition(false, true, true, true, "shared_inits"); + TestSinglePartition(true, true, true, true, "shared_inits"); } TEST_F(SparsifyGatherTest, TestMultiPartition) { diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index dbc81599de8539ce58933f9d40bf99fcae8f8e67..7717d8d7de27e827aab5208404f2e2275d60c8d3 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -99,6 +99,7 @@ genrule( "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", + "@aws//:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", @@ -114,6 +115,7 @@ genrule( "@libxsmm_archive//:LICENSE", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", + "@nasm//:LICENSE", "@nsync//:LICENSE", "@png_archive//:LICENSE", "@protobuf_archive//:LICENSE", @@ -134,6 +136,7 @@ genrule( "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", + "@aws//:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", @@ -149,6 +152,7 @@ genrule( "@libxsmm_archive//:LICENSE", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", + "@nasm//:LICENSE", "@nsync//:LICENSE", "@png_archive//:LICENSE", "@protobuf_archive//:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 598080ed2753b862056ebcc76c4c572ae45b46e6..a9c4a8de42a7633b09985cdd4470495c2c4749e2 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -88,13 +88,20 @@ filegroup( "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", "//third_party/hadoop:LICENSE.txt", + "@absl_py//absl/flags:LICENSE", + "@arm_neon_2_x86_sse//:LICENSE", + "@astor_archive//:LICENSE", + "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_google_absl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", "@fft2d//:fft/readme.txt", + "@flatbuffers//:LICENSE.txt", + "@gast_archive//:PKG-INFO", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@grpc//:LICENSE", @@ -105,11 +112,15 @@ filegroup( "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@grpc//third_party/nanopb:LICENSE.txt", + "@nasm//:LICENSE", "@nsync//:LICENSE", + "@pcre//:LICENCE", "@png_archive//:LICENSE", "@protobuf_archive//:LICENSE", "@six_archive//:LICENSE", "@snappy//:COPYING", + "@swig//:LICENSE", + "@termcolor_archive//:COPYING.txt", "@zlib_archive//:zlib.h", "@org_python_pypi_backports_weakref//:LICENSE", ] + if_mkl([ @@ -151,9 +162,10 @@ sh_binary( "//tensorflow/contrib/ndlstm:ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", - "//tensorflow/contrib/py2tf:py2tf_internal", + "//tensorflow/contrib/py2tf:py2tf", "//tensorflow/contrib/py2tf/converters:converters", "//tensorflow/contrib/py2tf/converters:test_lib", + "//tensorflow/contrib/py2tf/impl:impl", "//tensorflow/contrib/py2tf/pyct:pyct", "//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis", "//tensorflow/contrib/receptive_field:receptive_field_pip", diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 38a900738786e2413f5b1dd914caaebeafc92e21..73d759eb130633094b402c821cc32eb76c076a44 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -65,7 +65,6 @@ BLACKLIST = [ "//tensorflow/contrib/framework:checkpoint_ops_testdata", "//tensorflow/contrib/bayesflow:reinforce_simple_example", "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long - "//tensorflow/contrib/py2tf:py2tf_internal", "//tensorflow/contrib/timeseries/examples:predict", "//tensorflow/contrib/timeseries/examples:multivariate", "//tensorflow/contrib/timeseries/examples:known_anomaly", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 62df6453fb5d39728c2985a28a70a263d79804b1..20027869990013098c405b4707318a3ce63000fc 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -29,16 +29,17 @@ from setuptools.dist import Distribution # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.5.0-rc1' +_VERSION = '1.6.0-rc0' REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', 'astor >= 0.6.0', 'gast >= 0.2.0', + 'grpcio >= 1.8.6', 'numpy >= 1.12.1', 'six >= 1.10.0', 'protobuf >= 3.4.0', - 'tensorflow-tensorboard >= 0.4.0', + 'tensorflow-tensorboard >= 1.5.0, < 1.6.0', 'termcolor >= 1.1.0', ] @@ -79,13 +80,13 @@ CONSOLE_SCRIPTS = [ # is now declared by the tensorboard pip package. If we remove the # TensorBoard command, pip will inappropriately remove it during install, # even though the command is not removed, just moved to a different wheel. - 'tensorboard = tensorboard.main:main', + 'tensorboard = tensorboard.main:run_main', ] # pylint: enable=line-too-long # remove the tensorboard console script if building tf_nightly if 'tf_nightly' in project_name: - CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:main') + CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:run_main') TEST_PACKAGES = [ 'scipy >= 0.15.1', diff --git a/tensorflow/tools/test/file_name_test.py b/tensorflow/tools/test/file_name_test.py new file mode 100644 index 0000000000000000000000000000000000000000..16fb8a822d09ed136cf79dd2473fc202ca632d83 --- /dev/null +++ b/tensorflow/tools/test/file_name_test.py @@ -0,0 +1,48 @@ +#!/usr/bin/python +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Test that checks if we have any issues with case insensitive filesystems. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) +ERROR_MESSAGE = """ +Files with same name but different case detected in directory: {} +""" + + +def main(): + # Make sure BASE_DIR ends with tensorflow. If it doesn't, we probably + # computed the wrong directory. + if os.path.split(BASE_DIR)[-1] != 'tensorflow': + raise AssertionError( + "BASE_DIR = '%s' doesn't end with tensorflow" % BASE_DIR) + + for dirpath, dirnames, filenames in os.walk(BASE_DIR, followlinks=True): + lowercase_directories = [x.lower() for x in dirnames] + lowercase_files = [x.lower() for x in filenames] + + lowercase_dir_contents = lowercase_directories + lowercase_files + if len(lowercase_dir_contents) != len(set(lowercase_dir_contents)): + raise AssertionError(ERROR_MESSAGE.format(dirpath)) + + +if __name__ == '__main__': + main() diff --git a/tensorflow/tools/test/run_and_gather_logs_lib.py b/tensorflow/tools/test/run_and_gather_logs_lib.py index a953ed1b53d13504f92d2ffeb4c1ac6bcb0b8477..3b4921bb983a72223b092d99eb3fb59332fc6345 100644 --- a/tensorflow/tools/test/run_and_gather_logs_lib.py +++ b/tensorflow/tools/test/run_and_gather_logs_lib.py @@ -136,7 +136,7 @@ def run_and_gather_logs(name, test_name, test_args, gpu_config = gpu_info_lib.gather_gpu_devices() if gpu_config: gpu_name = gpu_config[0].model - gpu_short_name_match = re.search(r"Tesla (K40|K80|P100)", gpu_name) + gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name) if gpu_short_name_match: gpu_short_name = gpu_short_name_match.group(0) test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_") diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index c88727b2659e03a054757655038a9eec73d1eadb..eca744a920c8dacdad40d618484de87931e6d3c7 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -114,16 +114,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "5996380e3e8b981f55d1c8d58e709c00dbb4806ba367be75d0925a68cc2f6478", strip_prefix = "abseil-cpp-720c017e30339fd1786ce4aac68bc8559736e53f", + build_file = str(Label("//third_party:com_google_absl.BUILD")), ) tf_http_archive( name = "eigen_archive", urls = [ - "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/14e1418fcf12.tar.gz", - "https://bitbucket.org/eigen/eigen/get/14e1418fcf12.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/2355b229ea4c.tar.gz", + "https://bitbucket.org/eigen/eigen/get/2355b229ea4c.tar.gz", ], - sha256 = "2b526c6888639025323fd4f2600533c0f982d304ea48e4f1663e8066bd9f6368", - strip_prefix = "eigen-eigen-14e1418fcf12", + sha256 = "0cadb31a35b514bf2dfd6b5d38205da94ef326ec6908fc3fd7c269948467214f", + strip_prefix = "eigen-eigen-2355b229ea4c", build_file = str(Label("//third_party:eigen.BUILD")), ) @@ -352,11 +353,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "protobuf_archive", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", - "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", ], - sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", - strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", + sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", + strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", ) # We need to import the protobuf library under the names com_google_protobuf @@ -365,21 +366,21 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "com_google_protobuf", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", - "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", ], - sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", - strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", + sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", + strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", ) tf_http_archive( name = "com_google_protobuf_cc", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", - "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", ], - sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", - strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", + sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", + strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", ) tf_http_archive( @@ -472,11 +473,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/299f8c346e1ab483463da5f02536ffd00b7ad9c6.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/299f8c346e1ab483463da5f02536ffd00b7ad9c6.tar.gz", ], - sha256 = "b5429ccf8d57273cb8489714f728c997cd720ec66fc2c0292422ab8f0e729ce0", - strip_prefix = "llvm-11a2ca6eea8a7fe240a14c0c35fd2017341279be", + sha256 = "0556bc6a85000c573d92fe00946b6418cbcd3844912696a81055e4768299dda4", + strip_prefix = "llvm-299f8c346e1ab483463da5f02536ffd00b7ad9c6", build_file = str(Label("//third_party/llvm:llvm.BUILD")), ) diff --git a/third_party/com_google_absl.BUILD b/third_party/com_google_absl.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..8fca145f751eacfa3e5a0af046dcc8c19e6a85d4 --- /dev/null +++ b/third_party/com_google_absl.BUILD @@ -0,0 +1,5 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache + +exports_files(["LICENSE"]) diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD index f6b8e6ddb05e67a4bb4833a3bba6db3cbd4c79e0..824c97be60e7ef148a363b964ed330ba3c5fcb0c 100644 --- a/third_party/flatbuffers/flatbuffers.BUILD +++ b/third_party/flatbuffers/flatbuffers.BUILD @@ -4,6 +4,8 @@ package( licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE.txt"]) + config_setting( name = "freebsd", values = {"cpu": "freebsd"}, diff --git a/third_party/gast.BUILD b/third_party/gast.BUILD index 06db528ada27e2f26f6de48c1ce6e9348ce09173..4866982e1fda6d6f19e575c8b0c0273cb9de154b 100644 --- a/third_party/gast.BUILD +++ b/third_party/gast.BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # BSD 3-clause -exports_files(["LICENSE"]) +exports_files(["PKG-INFO"]) py_library( name = "gast", diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 8e1dd8a54f53f58a367af21ef1c17f2695431bad..255ae0119095ee17babc00f43e93fa4c4931c1fb 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -826,7 +826,7 @@ def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, if src_dir != None: src_dir = _norm_path(src_dir) dest_dir = _norm_path(dest_dir) - files = _read_dir(repository_ctx, src_dir) + files = '\n'.join(sorted(_read_dir(repository_ctx, src_dir).splitlines())) # Create a list with the src_dir stripped to use for outputs. dest_files = files.replace(src_dir, '').splitlines() src_files = files.splitlines() diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 5344525ba8b42e8a3dbcf42397458d190a77f9d3..a9e1341a03c2e725e96bd7c8cbd7b09853bb8af4 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -670,6 +670,28 @@ cc_library( ], ) +cc_library( + name = "aggressive_inst_combine", + srcs = glob([ + "lib/Transforms/AggressiveInstCombine/*.c", + "lib/Transforms/AggressiveInstCombine/*.cpp", + "lib/Transforms/AggressiveInstCombine/*.inc", + "lib/Transforms/AggressiveInstCombine/*.h", + ]), + hdrs = glob([ + "include/llvm/Transforms/AggressiveInstCombine/*.h", + "include/llvm/Transforms/AggressiveInstCombine/*.def", + "include/llvm/Transforms/AggressiveInstCombine/*.inc", + ]), + deps = [ + ":analysis", + ":config", + ":core", + ":support", + ":transform_utils", + ], +) + cc_library( name = "analysis", srcs = glob([ @@ -1405,6 +1427,7 @@ cc_library( "include/llvm/Transforms/IPO/*.inc", ]), deps = [ + ":aggressive_inst_combine", ":analysis", ":bit_reader", ":bit_writer", @@ -1931,6 +1954,7 @@ cc_library( "include/llvm/Transforms/IPO/SCCP.h", ]), deps = [ + ":aggressive_inst_combine", ":analysis", ":config", ":core", diff --git a/third_party/pcre.BUILD b/third_party/pcre.BUILD index e2cdec40295d369548ff26e3493b5d2300041916..3a8e7a10b43debb5eeca690a64d5795de998a3ac 100644 --- a/third_party/pcre.BUILD +++ b/third_party/pcre.BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # BSD -exports_files(["COPYING"]) +exports_files(["LICENCE"]) cc_library( name = "pcre", diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index c16eb3a12a86f3c2eb3813f5c8c7631fec8e97c6..954f21f5f8fe8029c869f8870464a750cfc8a3db 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -118,7 +118,7 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, if src_dir != None: src_dir = _norm_path(src_dir) dest_dir = _norm_path(dest_dir) - files = _read_dir(repository_ctx, src_dir) + files = '\n'.join(sorted(_read_dir(repository_ctx, src_dir).splitlines())) # Create a list with the src_dir stripped to use for outputs. dest_files = files.replace(src_dir, '').splitlines() src_files = files.splitlines() diff --git a/third_party/repo.bzl b/third_party/repo.bzl index 11e9c842d2f0c8deb123d7b13d85865b089d73d7..aa178fa8cab92d9d299e5ed09927d8572816a0af 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -27,7 +27,7 @@ def _wrap_bash_cmd(ctx, cmd): bazel_sh = _get_env_var(ctx, "BAZEL_SH") if not bazel_sh: fail("BAZEL_SH environment variable is not set") - cmd = [bazel_sh, "-c", " ".join(cmd)] + cmd = [bazel_sh, "-l", "-c", " ".join(cmd)] return cmd def _get_env_var(ctx, name): diff --git a/third_party/termcolor.BUILD b/third_party/termcolor.BUILD index 6000e3289deff8183193883a9b796da9384365b8..655d7cb85e584027d12014c53718a15e2522b4ae 100644 --- a/third_party/termcolor.BUILD +++ b/third_party/termcolor.BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # MIT -exports_files(["LICENSE"]) +exports_files(["COPYING.txt"]) py_library( name = "termcolor",
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.5.0-rc1CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.5.0-rc1GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.6.0rc0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.6.0rc0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.5.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.5.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.4.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.4.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.368
tensorflow-1.3.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A